Skip to content

Commit

Permalink
xe: sdpa: Fix alignment for the K and V tensors
Browse files (browse the repository at this point in the history)
  • Loading branch information
umar456 committed Dec 19, 2024
1 parent b516214 commit 89ddbe7
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/gpu/intel/ocl/micro_sdpa.cpp
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {

problem_kq.B.layout = MatrixLayout::Pr;
problem_kq.C.layout = MatrixLayout::T;
- problem_kq.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta));
+ const memory_desc_wrapper key_mdw(key_md());
+ auto ldk = static_cast<int>(
+ gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size());
+ problem_kq.A.setAlignment(alignmentForLD(ldk));
problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM
problem_kq.B.crosspack = 2;
problem_kq.B.tileR = into<uint16_t>(d_max());
@@ -331,7 +334,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {

problem_vs.B.layout = MatrixLayout::Pr;
problem_vs.C.layout = MatrixLayout::N;
- problem_vs.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta));
+ const memory_desc_wrapper val_mdw(val_md());
+ auto ldv = static_cast<int>(
+ gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size());
+ problem_vs.A.setAlignment(alignmentForLD(ldv));
problem_vs.B.setAlignment(64); // S is packed in SLM
problem_vs.B.crosspack = 16;
sizes.m = d->values();
Expand Down

0 comments on commit 89ddbe7

Please sign in to comment.