diff --git a/src/gpu/intel/jit/gemm/gen_gemm_kernel_generator.cpp b/src/gpu/intel/jit/gemm/gen_gemm_kernel_generator.cpp index b1a79d7bc48..bcf2cfd455e 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm_kernel_generator.cpp +++ b/src/gpu/intel/jit/gemm/gen_gemm_kernel_generator.cpp @@ -5847,9 +5847,12 @@ static inline bool canRelAddr(const RegisterBlock &blockSrc, } static inline int block2DWidthAlignment(Type T, const RegisterBlock &block, + const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy) { // Block 2D width must be DW-aligned, but generally use QW alignment for better performance for reads. - return ((astrategy.noExtraPad || block.writable) ? 4 : 8); + return ((astrategy.noExtraPad || block.writable || atype.alignment % 8) + ? 4 + : 8); } static inline int block2DBaseAlignment(HW hw, int stepping) { @@ -6071,7 +6074,7 @@ void gemm_kernel_generator_t::setupAddr(Type T, const GRFRange &addr, if (doBaseAdjust && !astrategy.address2D) stub(); Subregister baStorage, baseAdjust, baseAdjustElems; - int widthAlign = block2DWidthAlignment(T, block, astrategy); + int widthAlign = block2DWidthAlignment(T, block, atype, astrategy); if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u); @@ -6828,6 +6831,7 @@ void gemm_kernel_generator_t::remaskLayout(Type T, int index, bool column, } static bool needsRemask(Type T, bool column, const RegisterBlock &block, + const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) { if (!ignoreMasks) if (column ? 
!block.remainderC : !block.remainderR) return false; @@ -6839,8 +6843,8 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block, int maskGranularity = block.ebytes; if (block.ebytes >= 16) maskGranularity = 4; if (block2DRemask) - maskGranularity = std::max( - maskGranularity, block2DWidthAlignment(T, block, astrategy)); + maskGranularity = std::max(maskGranularity, + block2DWidthAlignment(T, block, atype, astrategy)); if (ignoreMasks && !(block2DRemask && astrategy.address2D)) maskGranularity = 256; @@ -6848,10 +6852,11 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block, } static bool needsRemask(Type T, bool column, - const vector<RegisterBlock> &layout, + const vector<RegisterBlock> &layout, const MatrixAddressing &atype, const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) { for (auto &block : layout) - if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true; + if (needsRemask(T, column, block, atype, astrategy, ignoreMasks)) + return true; return false; } @@ -14483,11 +14488,11 @@ void gemm_kernel_generator_t::kLoopActivateSLMRemainder(bool active, bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded; bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded; slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy - && needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai_strategy, - asIfMaskedAi); + && needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai, + state.Ai_strategy, asIfMaskedAi); slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy - && needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi_strategy, - asIfMaskedBi); + && needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi, + state.Bi_strategy, asIfMaskedBi); } static inline void kLoopModifiedFlagAP(GEMMState &state) { @@ -15341,11 +15346,11 @@ void gemm_kernel_generator_t::kLoop(KLoop type, const GEMMProblem &problem, // A/B remasking in k dimension, during remainder handling. 
bool remaskA = !slmA && readA && (minOPCount > 1) - && needsRemask(Ta_load, true, state.A_layoutRem, strategy.A, - state.A_lateKRem); + && needsRemask(Ta_load, true, state.A_layoutRem, problem.A, + strategy.A, state.A_lateKRem); bool remaskB = !slmB && readB && (minOPCount > 1) - && needsRemask(Tb_load, false, state.B_layoutRem, strategy.B, - state.B_lateKRem); + && needsRemask(Tb_load, false, state.B_layoutRem, problem.B, + strategy.B, state.B_lateKRem); if (Ta.isInteger() && Tb.isInteger() && !calcASums && !calcBSums) { // Only need to remask one operand for integer A/B. Choose the smaller one.