Skip to content

Commit

Permalink
gpu: jit: gemm: fix out-of-bounds m/n cooperative prefetches
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad authored and vpirogov committed Apr 8, 2023
1 parent f27dedb commit 2b8f6b1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 43 deletions.
83 changes: 42 additions & 41 deletions src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1375,10 +1375,10 @@ void gemm_kernel_generator_t<hw>::releaseFusedRemainders(GEMMState &state) {
state.remaindersFused[LoopN] = Subregister {};
}

static inline void releaseSLMRemainders(GEMMState &state) {
static inline void releaseCoopRemainders(GEMMState &state) {
for (LoopType loop : {LoopM, LoopN, LoopK})
if (state.remaindersSLM[loop] != state.remainders[loop])
state.ra.safeRelease(state.remaindersSLM[loop]);
if (state.remaindersCoop[loop] != state.remainders[loop])
state.ra.safeRelease(state.remaindersCoop[loop]);
}

template <HW hw>
Expand Down Expand Up @@ -12092,10 +12092,8 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
}

// k mask information.
// Should use state.remaindersSLM[Loop{M,N}] but in practice all those masks are
// already loaded.
Subregister rems[3]
= {state.remaindersSLM[LoopM], state.remaindersSLM[LoopN], state.K};
Subregister rems[3] = {
state.remaindersCoop[LoopM], state.remaindersCoop[LoopN], state.K};
int offsets[3] = {0, 0, -kOffset};

// If not changing between main loop and remainder, update k masks as needed and return.
Expand All @@ -12119,7 +12117,7 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
state.inputs.lda, false, true, AvoidFragment, state.Ai,
state.Ai_strategy, strategy, state, state.Ai_regCount);
if (!assignMasks(state.Ai_layoutRem, LoopM, LoopK, kMasksSLM, strategy,
state, true, &state.AiBi_masks))
state, true, &state.AB_masksCoop))
stub();
if (state.aioShare && state.Ao_regsRem.empty()
&& state.Ai_layoutRem[0].crosspack
Expand All @@ -12136,7 +12134,7 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
state.inputs.ldb, true, false, AvoidFragment, state.Bi,
state.Bi_strategy, strategy, state, state.Bi_regCount);
if (!assignMasks(state.Bi_layoutRem, LoopK, LoopN, kMasksSLM, strategy,
state, true, &state.AiBi_masks))
state, true, &state.AB_masksCoop))
stub();
if (state.bioShare && state.Bo_regsRem.empty()
&& state.Bi_layoutRem[0].crosspack
Expand Down Expand Up @@ -13545,6 +13543,26 @@ bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
state.effCoopA = effCoopSplitA(problem, strategy);
state.effCoopB = effCoopSplitB(problem, strategy);

// Prepare m/n remainders for m/n-cooperative operations.
for (LoopType loop : {LoopM, LoopN, LoopK})
state.remaindersCoop[loop] = state.remainders[loop];

if ((strategy.slmA || (strategy.prefetchA && strategy.cooperativePF))
&& remM_A && state.effCoopA == CoopSplit::MN) {
state.remaindersCoop[LoopM] = state.ra.alloc_sub<uint16_t>();
int32_t chunkM = unrollM / strategy.wg[LoopN];
emad(1 | sat, state.remaindersCoop[LoopM], state.remainders[LoopM],
-state.lidN.w(), chunkM, strategy, state);
}

if ((strategy.slmB || (strategy.prefetchB && strategy.cooperativePF))
&& remN_B && state.effCoopB == CoopSplit::MN) {
state.remaindersCoop[LoopN] = state.ra.alloc_sub<uint16_t>();
int32_t chunkN = unrollN / strategy.wg[LoopM];
emad(1 | sat, state.remaindersCoop[LoopN], state.remainders[LoopN],
-state.lidM.w(), chunkN, strategy, state);
}

// Prepare layouts for prefetch.
bool remM_Cp = remM_C && strategy.C.base.isStateless();
bool remN_Cp = remN_C && strategy.C.base.isStateless();
Expand Down Expand Up @@ -13617,9 +13635,6 @@ bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
}

// Prepare layouts and starting addresses for SLM copies and adjust problem.
for (LoopType loop : {LoopM, LoopN, LoopK})
state.remaindersSLM[loop] = state.remainders[loop];

if (strategy.slmBuffers > 0) {
int A_slmCP, B_slmCP;
int A_tileR, A_tileC, B_tileR, B_tileC;
Expand All @@ -13632,15 +13647,8 @@ bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
coopSplit(true, state.ma_slm, state.ka_slm, unrollM, unrollKSLM,
state.effCoopA, strategy.wg[LoopN], problem.A);

if (state.effCoopA == CoopSplit::MN) {
if (remM_A) {
state.remaindersSLM[LoopM] = state.ra.alloc_sub<uint16_t>();
emad(1 | sat, state.remaindersSLM[LoopM],
state.remainders[LoopM], -state.lidN.w(),
state.ma_slm, strategy, state);
}
if (state.effCoopA == CoopSplit::MN)
remK_A = remainderK && strategy.slmEarlyKMask;
}

if (strategy.slmATrans) {
A_slmCP = state.ka_slm;
Expand Down Expand Up @@ -13822,15 +13830,8 @@ bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
coopSplit(false, state.kb_slm, state.nb_slm, unrollKSLM, unrollN,
state.effCoopB, strategy.wg[LoopM], problem.B);

if (state.effCoopB == CoopSplit::MN) {
if (remN_B) {
state.remaindersSLM[LoopN] = state.ra.alloc_sub<uint16_t>();
emad(1 | sat, state.remaindersSLM[LoopN],
state.remainders[LoopN], -state.lidM.w(),
state.nb_slm, strategy, state);
}
if (state.effCoopB == CoopSplit::MN)
remK_B = remainderK && strategy.slmEarlyKMask;
}

if (strategy.slmBTrans) {
B_slmCP = state.kb_slm;
Expand Down Expand Up @@ -14170,21 +14171,21 @@ bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
// Try first without virtual flags and retry if needed.
// m/n cooperative SLM copies may use k masking; skip those masks for now.
auto &masks = state.AB_masks;
auto &imasks = state.AiBi_masks;
auto &Ai_masks = (state.effCoopA == CoopSplit::K) ? masks : imasks;
auto &Bi_masks = (state.effCoopB == CoopSplit::K) ? masks : imasks;
auto &masksCoop = state.AB_masksCoop;
auto &A_cmasks = (state.effCoopA == CoopSplit::K) ? masks : masksCoop;
auto &B_cmasks = (state.effCoopB == CoopSplit::K) ? masks : masksCoop;

auto assignAllMasks = [&]() {
return assignMasks(state.A_layout, LoopM, LoopK, masks, strategy, state)
&& assignMasks(
state.Ap_layout, LoopM, LoopK, masks, strategy, state)
&& assignMasks(state.Ap_layout, LoopM, LoopK, A_cmasks,
strategy, state)
&& assignMasks(state.Ai_layout, LoopM, LoopNone, A_cmasks,
strategy, state)
&& assignMasks(
state.B_layout, LoopK, LoopN, masks, strategy, state)
&& assignMasks(
state.Bp_layout, LoopK, LoopN, masks, strategy, state)
&& assignMasks(state.Ai_layout, LoopM, LoopNone, Ai_masks,
&& assignMasks(state.Bp_layout, LoopK, LoopN, B_cmasks,
strategy, state)
&& assignMasks(state.Bi_layout, LoopNone, LoopN, Bi_masks,
&& assignMasks(state.Bi_layout, LoopNone, LoopN, B_cmasks,
strategy, state);
};

Expand All @@ -14200,10 +14201,10 @@ bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
if (!success) return false;

loadMasks(masks, state.remainders, strategy, state);
loadMasks(imasks, state.remaindersSLM, strategy, state);
loadMasks(masksCoop, state.remaindersCoop, strategy, state);

if (!state.simd32KMasks)
releaseSLMRemainders(
releaseCoopRemainders(
state); /* may need cooperative (SLM copy/prefetch) m/n remainders for k masking later */

// Apply panel masks, if defined, to all A/B blocks.
Expand Down Expand Up @@ -14410,7 +14411,7 @@ void gemm_kernel_generator_t<hw>::gemmAccumulateCTeardown(
// We're done with A and B. Free their address, data, and flag registers.
// Also done with loop counter.
safeReleaseMaskAssignments(state.AB_masks, state);
safeReleaseMaskAssignments(state.AiBi_masks, state);
safeReleaseMaskAssignments(state.AB_masksCoop, state);
safeReleaseRanges(state.A_addrs, state);
safeReleaseRanges(state.B_addrs, state);
safeReleaseRanges(state.Ai_addrs, state);
Expand All @@ -14433,7 +14434,7 @@ void gemm_kernel_generator_t<hw>::gemmAccumulateCTeardown(
state.ra.safeRelease(state.broadcast_regs);
safeReleaseRanges(state.tempMul_regs, state);
clearTokenAllocations(hw, state);
releaseSLMRemainders(state);
releaseCoopRemainders(state);

deduplicateScalar(state.lda, state);
deduplicateScalar(state.ldb, state);
Expand Down
4 changes: 2 additions & 2 deletions src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1203,14 +1203,14 @@ struct GEMMState : public CommonState {
GRFMultirange Ao_regsRem, Bo_regsRem;
GRFMultirange As_regs, Bs_regs; // A row sums/B column sums.
GRFMultirange Ap_regs, Bp_regs, Cp_regs; // A/B/C prefetch registers.
std::vector<MaskAssignment> AB_masks, AiBi_masks;
std::vector<MaskAssignment> AB_masks, AB_masksCoop;
ngen::GRFRange broadcast_regs;
std::vector<ngen::GRFRange> tempMul_regs;
ngen::Subregister i0, j0, h0; // d
ngen::Subregister remainders[3]; // d (todo: w)
ngen::Subregister remaindersFused[2]; // w
ngen::Subregister remaindersWG[2]; // d (todo: w)
ngen::Subregister remaindersSLM[3]; // d
ngen::Subregister remaindersCoop[3]; // d
ngen::Subregister remFusedStorage; // d
ngen::Subregister diagC; // d
SubregisterPair lda, ldb;
Expand Down

0 comments on commit 2b8f6b1

Please sign in to comment.