
Support Smooth Softmax in GroupQueryAttention #21867

Merged
merged 9 commits into from
Aug 27, 2024

Conversation

tianleiwu
Contributor

@tianleiwu tianleiwu commented Aug 27, 2024

Description

This implements a modified softmax that adds one to the denominator. A similar idea was implemented as add_zero_attn in torch.nn.MultiheadAttention about six years ago, and later as softmax_with_extra_logit in flaxformer about three years ago. More recently, the formula was analyzed in *Attention Is Off By One* by Evan Miller.

Background

Softmax (formula 1) is defined as follows:

$$y_{i} = \frac{exp(x_{i})}{\sum_{j} exp(x_{j})}$$

After applying softmax, each element is in the range $(0, 1)$, and the elements of each row add up to 1, so they can be interpreted as probabilities.

In transformer models like BERT or LLaMA, softmax is used to normalize batch_size x num_heads x query_sequence_length rows, where each row has key_sequence_length elements (the total number of key tokens).
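
As a concrete illustration, here is a minimal NumPy sketch of formula 1 applied over the key dimension; the tensor names and shapes are ours, chosen for illustration, not the operator's actual signature:

```python
import numpy as np

# Hypothetical attention scores: one row per (batch, head, query token);
# each row has key_sequence_length elements.
batch_size, num_heads, q_len, k_len = 2, 4, 3, 5
scores = np.random.randn(batch_size, num_heads, q_len, k_len).astype(np.float32)

# Formula 1: softmax over the last axis (the key tokens).
e = np.exp(scores)
probs = e / e.sum(axis=-1, keepdims=True)

assert np.allclose(probs.sum(axis=-1), 1.0)  # every row sums to 1
```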

For this application, softmax has two potential issues:

  • When all elements are -inf (for example, when a whole row is masked because the query token is padding), the result is not a number (NaN), since exp(-inf)=0 and the above formula then divides zero by zero (demonstrated in the sketch after the example below).
  • Why should we normalize so that every row is treated as equally important (each row sums to 1)? In a language model, some words in a query carry more weight because they provide the key information that determines the context or meaning of the query.

For example, consider the query: "Find the best Italian restaurant near me."
In this sentence:

  • "Italian" and "restaurant" are the most important words because they define what the user is specifically looking for.
  • "Best" is also significant as it indicates the user is seeking a high-quality option.
  • "Near me" adds location-based context to the search.
  • Words like "find" and "the" are less important in this context because they don't add specific content to the query; they are more functional in nature.

Thus, in processing this query, a language model might give more attention or weight to the words "Italian," "restaurant," and "best" to understand and respond accurately.
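
The sketch below (plain NumPy, ours for illustration) makes the first issue concrete: a fully masked row turns into NaNs under formula 1.

```python
import numpy as np

def softmax(x):
    """Formula 1: plain softmax over the last axis."""
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# A fully masked row: every score is -inf (e.g., the query token is padding).
row = np.full(5, -np.inf)
print(softmax(row))  # [nan nan nan nan nan] -- exp(-inf)=0, so we divide 0 by 0
```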

Smooth Softmax (formula 2) is a modified version that introduces a smoothing term:

$$s_{i} = \frac{exp(x_{i})}{1+ \sum_{j} exp(x_{j})}$$

A straightforward interpretation of this formula is that there is an additional virtual element whose logit is zero; exp(0) then contributes the extra 1 in the denominator.
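
Equivalently, appending a virtual zero logit to the row and applying the standard softmax (formula 1) to the extended row recovers formula 2 for the original elements:

$$\frac{exp(x_{i})}{exp(0) + \sum_{j} exp(x_{j})} = \frac{exp(x_{i})}{1 + \sum_{j} exp(x_{j})} = s_{i}$$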

This formula addresses both issues (see the sketch below):

  • It handles the special case where all elements are -inf: the result $s_{i}$ is 0 for every element.
  • The sum of all elements $\sum_{i}{s_{i}} = \frac{\sum_{i}{exp(x_{i})}}{1+ \sum_{j} exp(x_{j})}$ is in the range $[0, 1)$, so the model can be trained to assign different importance to different rows. For a query word or a head that is not important, the whole row can be driven close to zero.
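
Both properties are easy to verify with a small NumPy sketch (ours, not the library's code):

```python
import numpy as np

def smooth_softmax(x):
    """Formula 2: softmax with an extra 1 (the virtual zero logit) in the denominator."""
    e = np.exp(x)
    return e / (1.0 + e.sum(axis=-1, keepdims=True))

masked = np.full(5, -np.inf)
print(smooth_softmax(masked))         # [0. 0. 0. 0. 0.] -- no NaN

row = np.array([0.5, -1.0, 2.0])
print(smooth_softmax(row).sum())      # strictly less than 1
```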

Since the exponential is prone to overflow and underflow, formula 3 can be used to obtain a numerically stable result:

$$s_{i} = \frac{exp(x_{i} + c)}{exp(c)+ \sum_{j} exp(x_{j} +c)}$$

In theory, c can be any value. In practice, the constant c should be chosen so that $exp(c)$ and $exp(x_{i} + c)$ do not overflow (or underflow) at the same time. A reasonable choice is formula 4:

$$c=-\max_{j} \{ x_j \}$$

or to apply the constraint c <= 0, as in formula 5:

$$c=-\max(0, \max_{j} \{ x_j \})$$

The latter (formula 5) ensures that $s_{i}$ falls back to formula 2 when all elements are negative.
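
A minimal sketch of formulas 3 and 5 combined (this mirrors the math, not the actual MLAS or flash attention code):

```python
import numpy as np

def smooth_softmax_stable(x):
    """Formula 3 with c chosen per formula 5: c = -max(0, max_j x_j).

    Since c <= 0 and x_i + c <= 0, both exp(c) and exp(x_i + c) stay in (0, 1],
    so nothing overflows. When every x_i is negative, c = 0 and this
    reduces to formula 2.
    """
    c = -np.maximum(0.0, x.max(axis=-1, keepdims=True))
    e = np.exp(x + c)
    return e / (np.exp(c) + e.sum(axis=-1, keepdims=True))

big = np.array([1000.0, 999.0, 998.0])  # naive formula 2 would overflow here
print(smooth_softmax_stable(big))       # finite, well-behaved probabilities
```

Formula 4 differs only in dropping the clamp against zero, i.e. $c = -\max_{j} \{ x_j \}$.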

For the CPU execution provider, smooth softmax is implemented in MLAS; the CPU implementation uses formula 5.

@wangyems implemented smooth softmax in flash attention for CUDA, which requires an Ampere or newer GPU. The flash attention implementation uses formula 4.

@tianleiwu tianleiwu requested a review from a team as a code owner August 27, 2024 01:02
@yufenglee
Member

For this "When $x_{i}$ is much smaller than 0, $exp(x_{i})$ will be close to 1. When all elements are much smaller than 0, the result will be close to 1 / (1 + N).", is it a typo "When $x_{i}$ is much smaller than 0"? It should be $x_{i}$ is much close to 0.

@tianleiwu
Contributor Author

For this "When x i is much smaller than 0, e x p ( x i ) will be close to 1. When all elements are much smaller than 0, the result will be close to 1 / (1 + N).", is it a typo "When x i is much smaller than 0"? It should be x i is much close to 0.

When $x_{i}$ is much smaller than 0, $exp(x_{i})$ will be close to 0. When all elements are much smaller than 0, the result will be close to 0.
When $x_{i}$ is close to 0, $exp(x_{i})$ will be close to 1. When all elements are close to 0, the result will be close to 1 / (1 + N).
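
A quick numeric check of both statements (plain NumPy sketch, with N = 4):

```python
import numpy as np

def smooth_softmax(x):
    """Formula 2: extra 1 in the denominator."""
    e = np.exp(x)
    return e / (1.0 + e.sum())

print(smooth_softmax(np.full(4, -50.0)))   # each ~2e-22: far below 0 -> result ~0
print(smooth_softmax(np.full(4, -1e-6)))   # each ~0.2, i.e. ~1/(1+N) with N=4
```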

@tianleiwu tianleiwu merged commit 6e57576 into main Aug 27, 2024
97 checks passed
@tianleiwu tianleiwu deleted the smooth_softmax_gqa_flash_and_cpu branch August 27, 2024 06:13
wangyems added a commit that referenced this pull request Aug 28, 2024
### Description
refer to #21867
wangyems added a commit that referenced this pull request Aug 28, 2024
refer to #21867
tianleiwu added a commit that referenced this pull request Aug 29, 2024
tianleiwu pushed a commit that referenced this pull request Aug 29, 2024
refer to #21867