Commit

address comments

aciddelgado committed Jan 22, 2024
1 parent dba1e7e commit e7863b3
Showing 5 changed files with 16 additions and 13 deletions.
17 changes: 8 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -18,7 +18,6 @@ void set_params_fprop(Flash_fwd_params& params,
                       size_t batch_size,
                       size_t seqlen_q,
                       size_t seqlen_k,
-                      size_t seqlen_k_max,
                       size_t seqlen_q_rounded,
                       size_t seqlen_k_rounded,
                       size_t num_heads,
@@ -63,16 +62,16 @@ void set_params_fprop(Flash_fwd_params& params,
     params.k_row_stride = head_size;
     params.v_row_stride = head_size;
     params.q_head_stride = head_size;
-    params.k_head_stride = seqlen_k_max * head_size;
-    params.v_head_stride = seqlen_k_max * head_size;
+    params.k_head_stride = seqlen_k * head_size;
+    params.v_head_stride = seqlen_k * head_size;
     params.o_row_stride = num_heads * head_size;
     params.o_head_stride = head_size;
   }
 
   if (cu_seqlens_q_d == nullptr) {
     params.q_batch_stride = seqlen_q * num_heads * head_size;  // stride(0)
-    params.k_batch_stride = seqlen_k_max * num_heads_k * head_size;  // stride(0)
-    params.v_batch_stride = seqlen_k_max * num_heads_k * head_size;  // stride(0)
+    params.k_batch_stride = seqlen_k * num_heads_k * head_size;  // stride(0)
+    params.v_batch_stride = seqlen_k * num_heads_k * head_size;  // stride(0)
     params.o_batch_stride = seqlen_q * num_heads * head_size;  // stride(0)
   } else {
     params.q_batch_stride = 0;
@@ -256,7 +255,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
   Flash_fwd_params params;
   set_params_fprop(params,
                    batch_size,
-                   seqlen_q, seqlen_k, seqlen_k,
+                   seqlen_q, seqlen_k,
                    seqlen_q_rounded, seqlen_k_rounded,
                    num_heads, num_heads_k,
                    head_size, head_size_rounded,
@@ -320,7 +319,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
   Flash_fwd_params params;
   set_params_fprop(params,
                    batch_size,
-                   max_seqlen_q, max_seqlen_k, max_seqlen_k,
+                   max_seqlen_q, max_seqlen_k,
                    seqlen_q_rounded, seqlen_k_rounded,
                    num_heads, num_heads_k,
                    head_size, head_size_rounded,
@@ -372,7 +371,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        int seqlen_q,
                        int seqlen_k,
                        int seqlen_k_new,
-                       int seqlen_k_max,
                        const float softmax_scale,
                        bool is_causal,
                        bool is_bf16,
@@ -388,10 +386,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
   const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
   const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
+  // In kv-cache case, seqlen_k is the kv sequence length (the present cache length).
   Flash_fwd_params params;
   set_params_fprop(params,
                    batch_size,
-                   seqlen_q, seqlen_k, seqlen_k_max,
+                   seqlen_q, seqlen_k,
                    seqlen_q_rounded, seqlen_k_rounded,
                    num_heads, num_heads_k,
                    head_size, head_size_rounded,
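The strides set above describe the key/value cache as one contiguous tensor of shape [batch, num_heads_k, seqlen_k, head_size] (BNSH in this branch), with seqlen_k now doubling as the allocated present-cache length instead of a separate seqlen_k_max. A minimal sketch of the offset arithmetic those strides imply; the function name and layout assumption are illustrative, not from the ONNX Runtime source:

    #include <cstddef>

    // Element offset of K[b][h][s][d] in a contiguous BNSH cache, using the
    // same stride definitions as set_params_fprop above.
    size_t kv_element_offset(size_t b, size_t h, size_t s, size_t d,
                             size_t seqlen_k, size_t num_heads_k, size_t head_size) {
      const size_t row_stride = head_size;                             // k_row_stride
      const size_t head_stride = seqlen_k * head_size;                 // k_head_stride
      const size_t batch_stride = seqlen_k * num_heads_k * head_size;  // k_batch_stride
      return b * batch_stride + h * head_stride + s * row_stride + d;
    }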
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -96,7 +96,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        int seqlen_q,
                        int seqlen_k,
                        int seqlen_k_new,
-                       int seqlen_k_max,
                        const float softmax_scale,
                        bool is_causal,
                        bool is_bf16,
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -30,7 +30,7 @@ class GroupQueryAttention final : public CudaKernel {
   float scale_;
   bool disable_flash_attention_;
   bool disable_memory_efficient_attention_;
-  static constexpr int kZerosCount = 256;
+  static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
   IAllocatorUniquePtr<int> zeros_;
 };

GitHub Actions / cpplint flagged the new comment on line 33:
  Lines should be <= 120 characters long  [whitespace/line_length] [2]
  At least two spaces is best between code and comments  [whitespace/comments] [2]
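The constant above backs a reusable zeros buffer: during prompt processing every sequence starts at past length 0, so one zero-filled int buffer of 256 entries can serve as seqlens_k for any batch_size up to 256. A rough sketch of that idea, assuming plain CUDA runtime calls rather than ORT's allocator (the helper name is hypothetical, and error checking is omitted):

    #include <cuda_runtime.h>

    // Hypothetical one-time setup: count zero-initialized int32 entries on the
    // device. A pointer to this buffer is usable as seqlens_k whenever
    // batch_size <= count; larger batches need a fill kernel instead.
    int* allocate_zero_seqlens(cudaStream_t stream, int count) {
      int* zeros = nullptr;
      cudaMalloc(reinterpret_cast<void**>(&zeros), count * sizeof(int));
      cudaMemsetAsync(zeros, 0, count * sizeof(int), stream);
      return zeros;
    }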
onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -209,6 +209,11 @@ Status CheckInputs(const Tensor* query,
     const auto& cos_dims = cos_cache->Shape().GetDims();
     const auto& sin_dims = sin_cache->Shape().GetDims();
 
+    if (head_size % 16 != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_size shall be a multiple of 16. Got head_size % 16 == ",
+                             head_size % 16);
+    }
     if (cos_dims[0] != present_sequence_length) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "cos_cache dimension 0 must be of present_sequence_length.");
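As a concrete example of the new check: head_size values of 128 or 80 pass (both are multiples of 16), while head_size = 72 is rejected with "head_size shall be a multiple of 16. Got head_size % 16 == 8".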
onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -510,7 +510,7 @@ Status FlashAttention(
   if (batch_size <= parameters.zeros_count) {
     seqlens_k = parameters.zero_ptr;
   } else {
-    // Launch kernel to copy seqlen
+    // Launch kernel to create larger seqlen tensor when batch_size > 256
     constexpr int thr_per_blk = 256;
     int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
     repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, 0, batch_size);
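From the launch above, repeat_seqlen writes the same sequence-length value (here 0) into one slot per batch entry. A hedged sketch of a kernel consistent with that call site — the signature is inferred from the arguments (pointer, value, count), not taken from the source:

    #include <cstdint>

    // One thread per batch entry; the grid is sized as ceil(batch_size / 256)
    // to cover batches larger than the preallocated zeros buffer.
    __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) {
      int id = blockDim.x * blockIdx.x + threadIdx.x;
      if (id < batch_size) {
        seqlens_k[id] = seqlen;  // every entry gets the same length
      }
    }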
@@ -530,7 +530,7 @@ Status FlashAttention(
       device_prop, stream, query, present_key, present_value, key, value, data.output,
       reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
       batch_size, num_heads, kv_num_heads, head_size, sequence_length,
-      parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.seqlen_present_kv_cache,
+      parameters.seqlen_present_kv_cache, kv_sequence_length,
       scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
       reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
       parameters.is_packed_qkv));