
[CUDA] enable causal in MultiHeadAttention #21852

Merged · 2 commits merged into main from tlwu/mha_causal on Aug 26, 2024
Conversation

@tianleiwu (Contributor) commented Aug 25, 2024

Description

Enable causal in MultiHeadAttention cuda operator.

All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H and QKV_BSN3H) now support causal in CUDA. Internally, causal attention is dispatched to the flash attention, efficient attention, or unfused attention kernel.
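
For reference, here is a minimal sketch of how a causal com.microsoft.MultiHeadAttention node could be built with the onnx Python helpers. It assumes the contrib op's unidirectional attribute is the switch that turns on causal masking as described above; the tensor names, shapes, and opset versions are illustrative, not taken from this PR.

```python
# Sketch: build a small ONNX model containing one MultiHeadAttention node
# with causal (unidirectional) attention. Shapes/names are assumptions.
import onnx
from onnx import TensorProto, helper

batch, seq_len, num_heads, head_size = 2, 16, 8, 64
hidden = num_heads * head_size

mha = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value"],   # Q_K_V_BSNH_BSNH_BSNH format
    outputs=["attn_out"],
    domain="com.microsoft",
    num_heads=num_heads,
    unidirectional=1,                   # causal: token i attends only to tokens <= i
)

graph = helper.make_graph(
    [mha],
    "mha_causal",
    inputs=[
        helper.make_tensor_value_info(name, TensorProto.FLOAT16, [batch, seq_len, hidden])
        for name in ("query", "key", "value")
    ],
    outputs=[
        helper.make_tensor_value_info("attn_out", TensorProto.FLOAT16, [batch, seq_len, hidden])
    ],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "mha_causal.onnx")
```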

Motivation and Context

Currently, MultiHeadAttention has causal enabled in the CPU EP but not in the CUDA EP. This can cause issues in ONNX conversion, e.g. some models can run on CPU but not on CUDA. Enabling causal in CUDA reduces the difference between the CPU and CUDA support matrices.

tianleiwu marked this pull request as draft August 25, 2024 18:32
tianleiwu marked this pull request as ready for review August 26, 2024 06:32
@kunal-vaishnavi (Contributor) commented:

Thank you for adding this! With this PR, can we now pass a 2D attention mask of shape (batch_size, total_sequence_length) directly to MultiHeadAttention as the key_padding_mask input instead of reformatting the attention mask from 2D to 4D causal? If so, can we use the key_padding_mask input this way for the CPU, CUDA, and DML EPs?

For context, the reformatting from 2D to 4D causal was done in the model builder because not all EPs implement causal attention masking with the key_padding_mask input. This caused discrepancies in the outputs from the ONNX model depending on the EP used.

@tianleiwu (Contributor, Author) commented Aug 26, 2024

> Thank you for adding this! With this PR, can we now pass a 2D attention mask of shape (batch_size, total_sequence_length) directly to MultiHeadAttention as the key_padding_mask input instead of reformatting the attention mask from 2D to 4D causal? If so, can we use the key_padding_mask input this way for the CPU, CUDA, and DML EPs?
>
> For context, the reformatting from 2D to 4D causal was done in the model builder because not all EPs implement causal attention masking with the key_padding_mask input. This caused discrepancies in the outputs from the ONNX model depending on the EP used.

Yes, you can use a 2D mask directly (no need to convert it to a 4D attention bias). However, the 2D mask currently only goes to the unfused kernel. In the future, we might add a 2D-to-1D conversion to help Hugging Face models, but that is not the best choice for performance since the conversion needs extra CUDA kernels.

I would suggest using a 1D mask of shape [B] (the total sequence length of each batch, assuming right-side padding, i.e. the reduce-sum of the 2D mask) if you want to benefit from flash attention.

Currently, memory efficient attention needs another 1D format with shape [3B+2]. In the future, we will add a conversion from the 1D mask of shape [B] to tensors compatible with memory efficient attention.
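
A small sketch of the 2D-to-1D reduction described above (assuming right-side padding); the array values are only an example, not from this PR:

```python
# Derive the 1D mask of shape [B] (per-batch total sequence length)
# from a 2D key_padding_mask of shape (batch_size, total_sequence_length).
import numpy as np

mask_2d = np.array(
    [[1, 1, 1, 1, 0, 0],    # batch 0: 4 valid tokens, 2 padding tokens
     [1, 1, 1, 1, 1, 1]],   # batch 1: 6 valid tokens, no padding
    dtype=np.int32,
)

# Reduce-sum over the sequence axis gives the per-batch sequence lengths.
mask_1d = mask_2d.sum(axis=1).astype(np.int32)   # shape [B] -> [4, 6]

# Feed mask_1d (instead of mask_2d) as key_padding_mask to take the
# flash attention path mentioned above; exact dispatch may vary.
print(mask_1d)
```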

tianleiwu merged commit ad38212 into main Aug 26, 2024
97 checks passed
tianleiwu deleted the tlwu/mha_causal branch August 26, 2024 20:34
tianleiwu added a commit that referenced this pull request Aug 29, 2024