Support Smooth Softmax in GroupQueryAttention #21867
Conversation
For this "When |
When x_i is much smaller than 0, exp(x_i) will be close to 0. When all elements are much smaller than 0, the result will be close to 0. |
Description
This implements a modified softmax that adds one to the denominator. A similar idea was implemented as add_zero_attn in torch.nn.MultiheadAttention about 6 years ago, and later as softmax_with_extra_logit in flaxformer about 3 years ago. More recently, the formula was analyzed in Attention Is Off By One by Evan Miller.
Background
Softmax (formula 1) is like the following:
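```math
y_{i} = \frac{exp(x_{i})}{\sum_{i} exp(x_{i})}
```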
After applying softmax, each element will be in the range of $(0, 1)$, and the elements will add up to 1, so that they can be interpreted as probabilities.
In transformer models like BERT or LLAMA, softmax is used to normalize batch_size x num_heads x query_sequence_length rows, where each row has key_sequence_length elements (the total number of tokens in the key).
For such applications, softmax has two potential issues:
* When all elements are -inf (for example, a whole row is masked when a query token is padding), the result is not defined, since exp(-inf) = 0 and the above formula divides by zero.
* Why do we need to normalize in a way that treats each query word as equally important (each row sums to 1)?
For example, consider the query: "Find the best Italian restaurant near me." In this sentence, the words "Italian," "restaurant," and "best" carry most of the meaning, so in processing this query a language model might give more attention or weight to those words to understand and respond accurately.
Smooth Softmax (formula 2) is a modified version that introduces a smooth factor like the following:
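```math
s_{i} = \frac{exp(x_{i})}{1 + \sum_{i} exp(x_{i})}
```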
A straightforward interpretation of this formula is that there is an additional virtual element equal to zero, so exp(0) contributes the extra 1 in the denominator.
This formula tackles the above two issues:
* It handles the special case where all elements are -inf: the result $s_{i}$ is 0 for every element in that case.
* The sum of all elements $\sum_{i}{s_{i}} = \frac{\sum_{i}{exp(x_{i})}}{1+ \sum_{i} exp(x_{i})}$ is in the range of $(0, 1)$, so the model can learn to assign different importance to different query words.
Since the exponential is prone to overflow or underflow, formula 3 can be used to get a stable result:
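```math
s_{i} = \frac{exp(x_{i} + c)}{exp(c) + \sum_{i} exp(x_{i} + c)}
```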
c can be any value in theory. In practice, the constant c should be chosen so that $exp(c)$ and $exp(x_{i} + c)$ do not overflow (or underflow) at the same time. A reasonable choice is formula 4:
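```math
c = -\max_{i} \{ x_i \}
```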
or apply the constraint $c \le 0$, as in formula 5:
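```math
c = -\max(0, \max_{i} \{ x_i \})
```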
The latter (formula 5) ensures that $s_{i}$ falls back to formula 2 when all elements are negative: c is then 0, so $exp(c) = 1$ and formula 3 reduces to formula 2.
For the CPU provider, smooth softmax is implemented in MLAS; the CPU implementation uses formula 5.
@wangyems implemented the smooth softmax in flash attention for CUDA, which requires an Ampere or newer GPU. The flash attention implementation of smooth softmax uses formula 4.
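As an illustration only (not the actual MLAS or flash attention kernel code), a minimal NumPy sketch of smooth softmax computed stably via formula 5 could look like the following; the function name `smooth_softmax` is hypothetical:

```python
import numpy as np

def smooth_softmax(x, axis=-1):
    """Sketch of smooth softmax (formula 2), computed stably via formula 5.

    The extra exp(c) term in the denominator corresponds to the virtual
    zero element; c = -max(0, max_i x_i) keeps the exponentials bounded.
    """
    c = -np.maximum(0.0, np.max(x, axis=axis, keepdims=True))
    e = np.exp(x + c)
    return e / (np.exp(c) + np.sum(e, axis=axis, keepdims=True))

# A fully masked row (all -inf) yields all zeros instead of NaN.
print(smooth_softmax(np.full(4, -np.inf)))               # [0. 0. 0. 0.]

# For finite inputs, the row sum is strictly less than 1.
print(smooth_softmax(np.array([1.0, 2.0, 3.0])).sum())   # about 0.97
```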