diff --git a/cmake/patches/cutlass/cutlass_3.5.0.patch b/cmake/patches/cutlass/cutlass_3.5.0.patch
index 3b829d2f8b2cf..93b8c474af9ed 100644
--- a/cmake/patches/cutlass/cutlass_3.5.0.patch
+++ b/cmake/patches/cutlass/cutlass_3.5.0.patch
@@ -1,13 +1,64 @@
+diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
+index 4c80f549..34327633 100644
+--- a/examples/41_fused_multi_head_attention/kernel_forward.h
++++ b/examples/41_fused_multi_head_attention/kernel_forward.h
+@@ -221,6 +221,8 @@ struct AttentionKernel {
+     int32_t num_batches = 0;
+     int32_t num_heads = 0;
+ 
++    bool use_smooth_softmax = false;
++
+     // dropout
+     bool use_dropout = false;
+     unsigned long long dropout_batch_head_rng_offset = 0;
+@@ -897,7 +899,8 @@ struct AttentionKernel {
+             p.num_keys - iter_key_start,
+             iter_key_start == 0,
+             iteratorC_tile_offset,
+-            kSupportsBias ? 1.0f : p.scale);
++            kSupportsBias ? 1.0f : p.scale,
++            p.use_smooth_softmax);
+ 
+         // Output results to shared-memory
+         int warp_idx_mn_0 = my_warp_id %
+@@ -1166,7 +1169,8 @@ struct AttentionKernel {
+       int max_col,
+       bool is_first,
+       typename WarpIteratorC::TensorCoord const& tile_offset,
+-      float scaling) {
++      float scaling,
++      bool use_smooth_softmax) {
+     /* Iterates on the accumulator and corresponding position on result matrix
+ 
+     (1) Update `mi[r]` to the max value of the row `r`
+@@ -1257,7 +1261,7 @@ struct AttentionKernel {
+       accum_t mi_row, total_row;
+       LambdaIterator::iterateRows(
+           lane_offset,
+-          [&](int accum_m) { mi_row = mi[accum_m]; },
++          [&](int accum_m) { mi_row = mi[accum_m];},
+           [&](int accum_m, int accum_n, int idx) {
+             frag[idx] =
+                 (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
+@@ -1294,7 +1298,7 @@ struct AttentionKernel {
+         for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
+           total_row += addition_storage[id + kQueriesPerBlock * i];
+         }
+-        s_prime[id] = total_row;
++        s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
+       }
+     }
+ 
 diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
 index 964d2ff3..b366bc14 100644
 --- a/include/cutlass/functional.h
 +++ b/include/cutlass/functional.h
 @@ -39,6 +39,7 @@
  #include "cutlass/numeric_types.h"
- 
+
  #include 
 +#include 
- 
+
  #if defined(CUTLASS_ARCH_WMMA_ENABLED)
  #include 
 @@ -230,8 +231,12 @@ struct inverse_square_root {
@@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
      return reinterpret_cast(result);
 +#else
 +    return half_t::convert((rsqrtf(half_t::convert(lhs))));
-+#endif 
++#endif
  #else
      return half_t(1.f / std::sqrt(half_t::convert(lhs)));
- #endif
+ #endif
\ No newline at end of file
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 8cd6296bf61a7..95a18621c05d3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -345,17 +345,18 @@ Status EfficientAttention(
   p.v_head_size = parameters.v_head_size;
   p.causal = parameters.is_unidirectional;
   p.scale = scale;
-  p.seqlen_k_ptr = nullptr == data.mask_index
-                       ? nullptr
-                       : const_cast(reinterpret_cast(data.mask_index));
-  p.seqstart_q_ptr = nullptr == data.mask_index
-                         ? nullptr
-                         : const_cast(reinterpret_cast(
-                               data.mask_index + parameters.batch_size));
-  p.seqstart_k_ptr = nullptr == data.mask_index
-                         ? nullptr
-                         : const_cast(reinterpret_cast(
-                               data.mask_index + 2 * parameters.batch_size + 1));
+  p.use_smooth_softmax = false;
+
+  if (nullptr == data.mask_index) {
+    p.seqlen_k_ptr = nullptr;
+    p.seqstart_q_ptr = nullptr;
+    p.seqstart_k_ptr = nullptr;
+  } else {
+    p.seqlen_k_ptr = const_cast(reinterpret_cast(data.mask_index));
+    p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size;
+    p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1;
+  }
+
   p.query = data.q;
   p.key = data.k;
   p.value = data.v;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index a5de20e44be1a..222c641883a90 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -214,6 +214,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
       p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
       p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0;
     }
+
+    p.use_smooth_softmax = params.use_smooth_softmax;
   }
 
   auto kernel_fn = attention_kernel_batched_impl;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index 08a562a12b844..81e70dab4e683 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -27,6 +27,7 @@ struct MemoryEfficientAttentionParams {
   bool causal;
   // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models.
   bool is_attn_bias_batched;
+  bool use_smooth_softmax;
 
   float scale;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 48ecfd7304f4b..1f378a184ab9b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -153,7 +153,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
 #if USE_MEMORY_EFFICIENT_ATTENTION
   int sm = (device_prop.major * 10) + device_prop.minor;
   bool use_memory_efficient_attention =
-      !use_smooth_softmax_ &&
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
       local_window_size_ == -1 &&
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index beac8aeebfb39..356f723902da7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -678,8 +678,8 @@ Status FlashAttention(
       reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
       batch_size, num_heads, kv_num_heads, head_size, sequence_length,
       parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
-      scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
-      reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum),
+      scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
+      reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum),
       parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));
 
   // if (parameters.left_padding && parameters.is_prompt) {
@@ -844,6 +844,7 @@ Status EfficientAttention(
                              : nullptr;
   p.stream = stream;
   p.has_custom_right_padding = true;
+  p.use_smooth_softmax = parameters.use_smooth_softmax;
   run_memory_efficient_attention(p);
 
   DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
index 2521cd49b5482..e4a5afd528a9a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass(
   p.qk_head_size = parameters.head_size;
   p.v_head_size = parameters.v_head_size;
   p.causal = false;
+  p.use_smooth_softmax = false;
   p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size))
                                      : parameters.scale;
   p.seqlen_k_ptr = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index e5a4c54f48903..3c25f9146edfd 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -693,6 +693,7 @@ Status FusedAttentionCutlass(
   p.qk_head_size = parameters.head_size;
   p.v_head_size = parameters.v_head_size;
   p.causal = false;
+  p.use_smooth_softmax = false;
   p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size))
                                      : parameters.scale;
   p.seqlen_k_ptr = nullptr;
diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py
index 17b9276a882eb..13bf51f74389a 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py
@@ -2219,7 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )
 
     @parameterized.expand(gqa_no_past_flash_attention_test_cases())
@@ -2263,7 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
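
Note (not part of the patch): the behavioral core of this change is the `s_prime[id] = total_row + exp2f(-mi[id])` line in kernel_forward.h. When `use_smooth_softmax` is on, the softmax denominator gains the contribution of a virtual zero logit, so a row's attention weights can sum to less than one. The sketch below is a plain C++ reference of that normalization, assuming unscaled natural-log-domain logits (the kernel works in the base-2 domain on pre-scaled scores, hence `exp2f`); the function name `smooth_softmax_row` is illustrative, not an ONNX Runtime API.

// Illustrative only: CPU reference for the row normalization changed in
// kernel_forward.h. Smooth softmax divides by 1 + sum(exp(x_j)) instead of
// sum(exp(x_j)), i.e. it appends a virtual logit of 0 to the row.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> smooth_softmax_row(const std::vector<float>& logits,
                                      bool use_smooth_softmax) {
  // Row maximum, as tracked in mi[] by the kernel.
  float m = *std::max_element(logits.begin(), logits.end());

  // Row sum of exponentials, as accumulated into s_prime[].
  float denom = 0.0f;
  std::vector<float> num(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    num[i] = std::exp(logits[i] - m);
    denom += num[i];
  }

  // Mirrors the patched line: s_prime[id] = total_row + exp2f(-mi[id]).
  // exp(0 - m) is the contribution of the virtual zero logit.
  if (use_smooth_softmax) {
    denom += std::exp(-m);
  }

  for (float& v : num) {
    v /= denom;
  }
  return num;  // The row may sum to less than 1 when smoothing is on.
}

int main() {
  std::vector<float> row = {-1.0f, 0.5f, 2.0f};
  for (float p : smooth_softmax_row(row, /*use_smooth_softmax=*/true)) {
    std::printf("%f ", p);
  }
  std::printf("\n");
  return 0;
}

With smoothing enabled the row sums to sum(exp(x)) / (1 + sum(exp(x))), which lets a query place little total weight on the keys when none of them score highly; this is the `use_smooth_softmax` option that GroupQueryAttention forwards through MemoryEfficientAttentionParams in the diff above.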