Skip to content

Commit

Permalink
Support Smooth Softmax in fmha (#21885)
Browse files Browse the repository at this point in the history
<!-- Describe your changes. -->
refer to #21867

<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Your Name <[email protected]>
  • Loading branch information
2 people authored and tianleiwu committed Aug 29, 2024
1 parent 1537105 commit 37f896d
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 20 deletions.
59 changes: 55 additions & 4 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
Original file line number Diff line number Diff line change
@@ -1,13 +1,64 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;

+ bool use_smooth_softmax = false;
+
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
- kSupportsBias ? 1.0f : p.scale);
+ kSupportsBias ? 1.0f : p.scale,
+ p.use_smooth_softmax);

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
- float scaling) {
+ float scaling,
+ bool use_smooth_softmax) {
/* Iterates on the accumulator and corresponding position on result matrix

(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
- [&](int accum_m) { mi_row = mi[accum_m]; },
+ [&](int accum_m) { mi_row = mi[accum_m];},
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
- s_prime[id] = total_row;
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
}
}

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"

#include <cuda_runtime.h>
+#include <cuda_fp16.h>

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
Expand All @@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
#endif
23 changes: 12 additions & 11 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -345,17 +345,18 @@ Status EfficientAttention(
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.scale = scale;
p.seqlen_k_ptr = nullptr == data.mask_index
? nullptr
: const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
p.seqstart_q_ptr = nullptr == data.mask_index
? nullptr
: const_cast<int32_t*>(reinterpret_cast<const int32_t*>(
data.mask_index + parameters.batch_size));
p.seqstart_k_ptr = nullptr == data.mask_index
? nullptr
: const_cast<int32_t*>(reinterpret_cast<const int32_t*>(
data.mask_index + 2 * parameters.batch_size + 1));
p.use_smooth_softmax = false;

if (nullptr == data.mask_index) {
p.seqlen_k_ptr = nullptr;
p.seqstart_q_ptr = nullptr;
p.seqstart_k_ptr = nullptr;
} else {
p.seqlen_k_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size;
p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1;
}

p.query = data.q;
p.key = data.k;
p.value = data.v;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
}

p.use_smooth_softmax = params.use_smooth_softmax;
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct MemoryEfficientAttentionParams {
bool causal;
// The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models.
bool is_attn_bias_batched;
bool use_smooth_softmax;

float scale;

Expand Down
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_smooth_softmax_ &&
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,8 @@ Status FlashAttention(
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));

// if (parameters.left_padding && parameters.is_prompt) {
Expand Down Expand Up @@ -844,6 +844,7 @@ Status EfficientAttention(
: nullptr;
p.stream = stream;
p.has_custom_right_padding = true;
p.use_smooth_softmax = parameters.use_smooth_softmax;
run_memory_efficient_attention(p);

DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass(
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
p.use_smooth_softmax = false;
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ Status FusedAttentionCutlass(
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
p.use_smooth_softmax = false;
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/python/transformers/test_flash_attn_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_no_past_flash_attention_test_cases())
Expand Down Expand Up @@ -2263,7 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
Expand Down

0 comments on commit 37f896d

Please sign in to comment.