GQA Rotary and Packed QKV with Flash #18906

Merged Jan 24, 2024 (58 commits; changes from 56 commits shown)

Commits
ef82d4d
squash merge
aciddelgado Oct 12, 2023
54f2526
add unit test and fix build
aciddelgado Oct 12, 2023
e2e1157
undo work in attention_impl file
aciddelgado Oct 17, 2023
415440b
reduce tests and change default behavior for past-kv is nullptr
aciddelgado Oct 19, 2023
0e84d2d
test compatibility w/ no cuda
aciddelgado Oct 19, 2023
9ba6963
exclude from amd
aciddelgado Oct 19, 2023
6573133
fix test script
aciddelgado Oct 20, 2023
bd79b6d
work on local attention flash
aciddelgado Oct 23, 2023
fb8e386
vscode idk
aciddelgado Oct 23, 2023
a87f211
make kernels more efficient and make present output required
aciddelgado Oct 24, 2023
16bda28
Merge branch 'main' into aciddelgado/gqa_memeff_v2
aciddelgado Oct 24, 2023
08f553d
merge main and memeff changes
aciddelgado Oct 24, 2023
eb12522
address comments
aciddelgado Oct 25, 2023
6c6aead
update ContribOperators.md
aciddelgado Oct 25, 2023
6e540e3
merge main
aciddelgado Oct 25, 2023
b3a9d0f
Merge branch 'aciddelgado/gqa_memeff_v2' into aciddelgado/gqa_local
aciddelgado Oct 25, 2023
3d7f3bf
local working with flash not memeff
aciddelgado Oct 26, 2023
db307f3
clarify input and output formats memory efficient attention
aciddelgado Oct 27, 2023
e7a50ee
max sequence length for memory efficient attention
aciddelgado Oct 27, 2023
bc1cf0a
clang and fix test file
aciddelgado Oct 27, 2023
44ca857
undo clang on unrelated files
aciddelgado Oct 27, 2023
f08495f
Merge branch 'main' into aciddelgado/gqa_memeff_v2
aciddelgado Oct 30, 2023
660c8fd
check value and key inputs
aciddelgado Oct 31, 2023
b9c4d15
key and value dont check for nullptr since they are required
aciddelgado Oct 31, 2023
afe22a4
fix up test script
aciddelgado Oct 31, 2023
0125678
fix packedmha, clean test, merge gqa_memeff_v2 branch changes
aciddelgado Nov 1, 2023
353c4f5
Merge branch 'main' into aciddelgado/gqa_local
aciddelgado Nov 8, 2023
94b5efb
local working with recent changes
aciddelgado Nov 8, 2023
ddb7a66
no local w memeff
aciddelgado Nov 9, 2023
d33c69b
undo unnecessary changes
aciddelgado Nov 9, 2023
26ca6d5
undo change symbolic_shape_infer.py
aciddelgado Nov 9, 2023
163a81e
fix pipeline
aciddelgado Nov 9, 2023
4326c8d
docs
aciddelgado Nov 10, 2023
b0a1006
make prompt to use the kv new
yufenglee Nov 12, 2023
23c22a3
clean up
aciddelgado Nov 16, 2023
e97a6fd
update documentation
aciddelgado Nov 16, 2023
93cb019
Merge branch 'main' into aciddelgado/gqa_local
aciddelgado Nov 16, 2023
3c332f0
start work rotary
aciddelgado Nov 20, 2023
b26b4cb
Merge branch 'main' into aciddelgado/gqa_rotary
aciddelgado Nov 20, 2023
082f347
rotary work
aciddelgado Dec 13, 2023
ae34d0c
Merge branch 'main' into aciddelgado/gqa_rotary
aciddelgado Dec 13, 2023
791bbc3
Merge branch 'yufeng/gqa_opt' into aciddelgado/gqa_rotary
aciddelgado Dec 13, 2023
6e6ad2c
rotary fully implemented
aciddelgado Dec 15, 2023
f67316d
packed working
aciddelgado Dec 20, 2023
8084585
run formatters
aciddelgado Dec 21, 2023
94afb76
enable gpu linux pipeline transformers tests
aciddelgado Dec 22, 2023
40cfc26
docs and pipeline test
aciddelgado Dec 22, 2023
a821588
Merge branch 'main' into aciddelgado/gqa_rotary_packed
aciddelgado Jan 11, 2024
b473283
test
aciddelgado Jan 11, 2024
152d920
run pipeline
aciddelgado Jan 11, 2024
f271a74
retrigger checks
aciddelgado Jan 11, 2024
ba13a3f
conflict and requirements change
aciddelgado Jan 12, 2024
723637e
disable transformers test
aciddelgado Jan 12, 2024
87533ef
add todo and format
aciddelgado Jan 12, 2024
615d500
fix lint issue
aciddelgado Jan 12, 2024
dba1e7e
merge conflict
aciddelgado Jan 19, 2024
e7863b3
address comments
aciddelgado Jan 22, 2024
5b55424
lintrunner
aciddelgado Jan 22, 2024
16 changes: 12 additions & 4 deletions docs/ContribOperators.md
@@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes

<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
</dl>

#### Inputs
#### Inputs (7 - 9)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
<dt><tt>value</tt> : T</dt>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of total sequence length (past + new).</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>
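
To make the packed-QKV and rotary-cache shapes above concrete, here is a small illustrative sketch (not part of this PR's diff); the head counts and sizes are assumed example values, not operator defaults.

```cpp
// Illustrative only: how the packed QKV width and the rotary cache shapes
// follow from the attributes above. All sizes are assumed example values.
#include <cstdio>

int main() {
  const int num_heads = 32;             // attribute: num_heads (Q heads)
  const int kv_num_heads = 8;           // attribute: kv_num_heads (K/V heads)
  const int head_size = 128;
  const int max_sequence_length = 4096;

  // Packed QKV: key/value inputs are omitted and query carries Q, K and V,
  // so its last dimension is d = num_heads*head_size + 2*kv_num_heads*head_size.
  const int d = num_heads * head_size + 2 * kv_num_heads * head_size;
  std::printf("packed QKV last dim d = %d\n", d);  // 6144 for these sizes

  // With do_rotary=1, cos_cache and sin_cache are (max_sequence_length, head_size / 2).
  std::printf("cos/sin cache shape = (%d, %d)\n", max_sequence_length, head_size / 2);
  return 0;
}
```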

#### Outputs
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;
};

namespace attention {
70 changes: 44 additions & 26 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -18,6 +18,7 @@
size_t batch_size,
size_t seqlen_q,
size_t seqlen_k,
size_t seqlen_k_max,
size_t seqlen_q_rounded,
size_t seqlen_k_rounded,
size_t num_heads,
@@ -62,17 +63,17 @@
params.k_row_stride = head_size;
params.v_row_stride = head_size;
params.q_head_stride = head_size;
params.k_head_stride = seqlen_k * head_size;
params.v_head_stride = seqlen_k * head_size;
params.k_head_stride = seqlen_k_max * head_size;
params.v_head_stride = seqlen_k_max * head_size;
params.o_row_stride = num_heads * head_size;
params.o_head_stride = head_size;
}

if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.k_batch_stride = seqlen_k_max * num_heads_k * head_size; // stride(0)
params.v_batch_stride = seqlen_k_max * num_heads_k * head_size; // stride(0)
params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
} else {
params.q_batch_stride = 0;
params.k_batch_stride = 0;
@@ -255,7 +256,7 @@
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q, seqlen_k, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
@@ -319,7 +320,7 @@
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
max_seqlen_q, max_seqlen_k,
max_seqlen_q, max_seqlen_k, max_seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
@@ -355,32 +356,33 @@
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
int head_size,
int seqlen_q,
int seqlen_k,
int seqlen_k_new,
int seqlen_k_max,
const float softmax_scale,
bool is_causal,
bool is_bf16,
bool past_bsnh, // otherwise bnsh
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size) {
// if (seqlen_q == 1) {
// is_causal = false;
// } // causal=true is the same as causal=false in this case

int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -389,7 +391,7 @@
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q, seqlen_k, seqlen_k_max,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
@@ -406,15 +408,24 @@
is_causal ? 0 : -1);
params.dprops = &dprops;

if (k != nullptr && v != nullptr) {
if (k_new != nullptr && v_new != nullptr) {
params.seqlen_knew = seqlen_k_new;
params.knew_ptr = k;
params.vnew_ptr = v;
params.knew_ptr = k_new;
params.vnew_ptr = v_new;
// All stride are in elements, not bytes.
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
if (is_packed_qkv) {
params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
} else {
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
}
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;
} else {
@@ -434,6 +445,13 @@
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
}

if (rotary_cos != nullptr) {
params.rotary_cos_ptr = rotary_cos;
params.rotary_sin_ptr = rotary_sin;
params.is_rotary_interleaved = is_rotary_interleaved;
params.rotary_dim = (head_size / 16) * 16;
}

params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
@@ -444,7 +462,7 @@
}

// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);

return Status::OK();
}
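
A minimal sketch of the stride arithmetic used in the packed-QKV branch above, assuming seqlen_q equals seqlen_k_new (which holds when Q, K_new and V_new come from the same packed rows); the helper below is hypothetical and only mirrors the expressions in the diff.

```cpp
// Hedged reference for the packed-QKV strides: each row of the packed input
// holds [Q | K_new | V_new], so per-row strides use the full packed width.
#include <cassert>

struct PackedStrides {
  int row_stride;    // elements between consecutive tokens
  int batch_stride;  // elements between consecutive batch entries
};

PackedStrides packed_qkv_strides(int seqlen, int num_heads, int num_heads_k, int head_size) {
  const int row = (num_heads + 2 * num_heads_k) * head_size;
  return {row, seqlen * row};
}

int main() {
  // Assumed example sizes; with seqlen_q == seqlen_k_new the batch stride equals
  // seqlen_q*num_heads*head_size + 2*seqlen_k_new*num_heads_k*head_size as set above.
  const PackedStrides s = packed_qkv_strides(/*seqlen=*/1, /*num_heads=*/32,
                                             /*num_heads_k=*/8, /*head_size=*/128);
  assert(s.row_stride == (32 + 2 * 8) * 128);
  assert(s.batch_stride == s.row_stride);
  return 0;
}
```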
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -87,21 +87,26 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
int head_size,
int seqlen_q,
int seqlen_k,
int seqlen_k_new,
int seqlen_k_max,
const float softmax_scale,
bool is_causal,
bool is_bf16,
bool past_bsnh, // otherwise bnsh
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1);
int local_window_size = -1,
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
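
As a side note on the rotary arguments declared above: rotary_cos and rotary_sin carry seqlen_ro x (rotary_dim / 2) values, and the caller in flash_api.cc derives rotary_dim by rounding head_size down to a multiple of 16. A tiny sketch with assumed head sizes:

```cpp
// Illustrative check of the rotary_dim rounding used by the caller
// (params.rotary_dim = (head_size / 16) * 16) and the resulting cache row length.
#include <cstdio>
#include <initializer_list>

int main() {
  for (int head_size : {64, 80, 96, 128}) {        // assumed example head sizes
    const int rotary_dim = (head_size / 16) * 16;  // largest multiple of 16 <= head_size
    std::printf("head_size=%d -> rotary_dim=%d, cos/sin row length=%d\n",
                head_size, rotary_dim, rotary_dim / 2);
  }
  return 0;
}
```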

26 changes: 24 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -47,6 +47,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
kv_num_heads_ = static_cast<int>(kv_num_heads);
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

#if USE_FLASH_ATTENTION
@@ -62,6 +64,9 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_memory_efficient_attention_ = true;
#endif
if (!disable_flash_attention_) {
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
}
}

template <typename T>
@@ -73,6 +78,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* seqlens_k = context->Input<Tensor>(5);
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);

auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@@ -84,6 +91,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
value,
past_key,
past_value,
cos_cache,
sin_cache,
&parameters,
num_heads_,
kv_num_heads_,
@@ -93,7 +102,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
scale_,
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;
int sequence_length = parameters.sequence_length;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;

TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
@@ -139,6 +154,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
do_rotary_ == false &&
key != nullptr &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -182,8 +199,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* present_value = context->Output(2, present_shape);

data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -229,6 +246,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
// Rotary
if (parameters.do_rotary) {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}

cublasHandle_t cublas = GetCublasHandle(context);
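
As a reading aid for the do_rotary path wired up above, here is a hedged CPU sketch of one rotary embedding step in both the interleaved and half-split layouts; it is an independent reference, not the CUDA kernel this operator actually dispatches to.

```cpp
// Reference-only rotary embedding for one head vector at one position.
// cos_row/sin_row are single rows of cos_cache/sin_cache (length rotary_dim / 2).
void apply_rotary(float* x, const float* cos_row, const float* sin_row,
                  int rotary_dim, bool interleaved) {
  for (int i = 0; i < rotary_dim / 2; ++i) {
    // rotary_interleaved=1 rotates adjacent pairs (x0,x1),(x2,x3),...;
    // otherwise it pairs x[i] with x[i + rotary_dim/2].
    const int a = interleaved ? 2 * i : i;
    const int b = interleaved ? 2 * i + 1 : i + rotary_dim / 2;
    const float xa = x[a];
    const float xb = x[b];
    x[a] = xa * cos_row[i] - xb * sin_row[i];
    x[b] = xa * sin_row[i] + xb * cos_row[i];
  }
}
```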

5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel {
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
int local_window_size_;
bool is_unidirectional_;
bool is_past_bsnh_;
bool do_rotary_;
bool rotary_interleaved_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
static constexpr int kZerosCount = 256;
IAllocatorUniquePtr<int> zeros_;
};

} // namespace cuda