From 280bd28fd8ba33aa99df1abb1de5ec782dab2159 Mon Sep 17 00:00:00 2001
From: Peter Caday
Date: Mon, 16 Sep 2024 10:20:18 -0700
Subject: [PATCH] Masks: restrict rdivide field to powers of 2

---
 .../gemm/generator/pieces/layout_setup.cxx    | 21 ++++++++++---------
 .../intel/jit/gemm/generator/pieces/masks.cxx | 14 ++++++-------
 .../gemm/generator/pieces/register_block.hpp  |  2 +-
 3 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/src/gpu/intel/jit/gemm/generator/pieces/layout_setup.cxx b/src/gpu/intel/jit/gemm/generator/pieces/layout_setup.cxx
index 0d939e743b2..43f2d19b119 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/layout_setup.cxx
+++ b/src/gpu/intel/jit/gemm/generator/pieces/layout_setup.cxx
@@ -262,7 +262,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
             vymask.bitRep = consecutive;
             vymask.maskRep = 1;
             vymask.rsize = *yblock;
-            vymask.rdivide = 1;
+            vymask.rshift = 0;
         } else if (logicalSlots < slots) {
             auto &fymask = block.colMajor ? block.rowMask.fixed : block.colMask.fixed;
             fymask.isFixed = true;
@@ -279,7 +279,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
             vxmask.bitRep = (block.simdSize > 16) ? 32 : 16;
             vxmask.maskRep = 1;
             vxmask.rsize = 1;
-            vxmask.rdivide = 1;
+            vxmask.rshift = 0;
         } else if (allowDesc && (channelScattered || astrategy.newDP) && *xblock > 1 && !byte) {
             fragment = std::min(*xblock, 4 * width / T);
             if (block.colMajor)         // Clang can't handle the ternary operator equivalent of this.
@@ -482,7 +482,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
                     vrmask.rsize = rblock;
                     vrmask.bitRep = std::max(T.paddedSize() / maskGranularity, 1);
                     vrmask.maskRep = cblock;
-                    vrmask.rdivide = std::max(maskGranularity / T, 1);
+                    vrmask.rshift = ilog2(std::max(maskGranularity / T, 1));
                 }
             } else {
                 if (avoidFragment) {
@@ -491,8 +491,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
                     vrmask.isFixed = false;
                     vrmask.bitRep = 0; /* will be filled in later */
                     vrmask.maskRep = 1;
-                    vrmask.rdivide = 1;
                     vrmask.rsize = 1;
+                    vrmask.rshift = 0;
                 } else {
                     // Fragment it. Could actually handle rowFragment = 2 by changing descriptor.
                     block.rowFragment = 1;
@@ -520,7 +520,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
                     vcmask.rsize = cblock;
                     vcmask.bitRep = std::max(T.paddedSize() / maskGranularity, 1);
                     vcmask.maskRep = rblock;
-                    vcmask.rdivide = std::max(maskGranularity / T, 1);
+                    vcmask.rshift = ilog2(std::max(maskGranularity / T, 1));
                 }
             } else {
                 if (avoidFragment) {
@@ -529,8 +529,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
                     vcmask.isFixed = false;
                     vcmask.bitRep = 0;
                     vcmask.maskRep = 1;
-                    vcmask.rdivide = 1;
                     vcmask.rsize = 1;
+                    vcmask.rshift = 0;
                 } else {
                     // Fragment it. Could actually handle colFragment = 2 by changing descriptor.
                     block.colFragment = 1;
@@ -719,7 +719,8 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
                 auto &vxmask = block.colMajor ? block.rowMask.variable : block.colMask.variable;
                 vxmask.isFixed = false;
                 vxmask.bitRep = block.simdSize;
-                vxmask.maskRep = vxmask.rdivide = vxmask.rsize = 1;
+                vxmask.maskRep = vxmask.rsize = 1;
+                vxmask.rshift = 0;
             }
 
             if (remainderY) {
@@ -728,7 +729,7 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
                 vymask.bitRep = xCacheLines;
                 vymask.maskRep = 1;
                 vymask.rsize = yblock;
-                vymask.rdivide = 1;
+                vymask.rshift = 0;
             }
             break;
         }
@@ -739,13 +740,13 @@ bool BLASKernelGenerator<hw>::getBlockInfo(Type T, const MatrixAddressing &atype
     if (block.rowMask && !block.rowMask.fixed.isFixed) {
         if (vrmask.rsize == 0)
             vrmask.rsize = rblock;
-        vrmask.maskRep = std::min(vrmask.maskRep, std::max(1, vrmask.rdivide * block.simdSize / (vrmask.bitRep * vrmask.rsize)));
+        vrmask.maskRep = std::min(vrmask.maskRep, std::max(1, (block.simdSize << vrmask.rshift) / (vrmask.bitRep * vrmask.rsize)));
         block.noRowsOK = true;          // All-zero masks are always OK.
     }
     if (block.colMask && !block.colMask.fixed.isFixed) {
         if (vcmask.rsize == 0)
             vcmask.rsize = cblock;
-        vcmask.maskRep = std::min(vcmask.maskRep, std::max(1, vcmask.rdivide * block.simdSize / (vcmask.bitRep * vcmask.rsize)));
+        vcmask.maskRep = std::min(vcmask.maskRep, std::max(1, (block.simdSize << vcmask.rshift) / (vcmask.bitRep * vcmask.rsize)));
        block.noColsOK = true;
     }
diff --git a/src/gpu/intel/jit/gemm/generator/pieces/masks.cxx b/src/gpu/intel/jit/gemm/generator/pieces/masks.cxx
index 756be43e2ff..f1f6aebd39f 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/masks.cxx
+++ b/src/gpu/intel/jit/gemm/generator/pieces/masks.cxx
@@ -127,7 +127,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
     // Load a variable mask, which requires some minor bit-twiddling.
     auto &vmask = assignment.mask.variable;
 
-    uint32_t rsizeScaled = vmask.rsize / vmask.rdivide;
+    uint32_t rsizeScaled = vmask.rsize >> vmask.rshift;
     uint32_t maskLen = vmask.bitRep * vmask.maskRep * rsizeScaled;
     uint32_t fullMask = (uint64_t(1) << maskLen) - 1;
     uint32_t rep1Mask = (uint64_t(1) << (vmask.bitRep * rsizeScaled)) - 1;
@@ -136,7 +136,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
     auto flagType = flag.getType();
     auto mask0Type = getBytes(flagType) >= 4 ? DataType::uq : flagType;
 
-    if (vmask.rsize == 1 && vmask.rdivide == 1) {
+    if (vmask.rsize == 1 && vmask.rshift == 0) {
         // Simple threshold comparison.
         offset += assignment.offset;
         if (flag.isARF())
@@ -152,11 +152,11 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
         auto mask0 = state.ra.alloc_sub(mask0Type, getHint(HintType::Bank1));
         auto mask = mask0.reinterpret(0, flagType);
         auto mindex = index;
+        auto rdivide = 1 << vmask.rshift;
 
-        if (vmask.rdivide > 1) {
-            if (!is_zero_or_pow2(vmask.rdivide)) stub();
-            add(1 | sat, temp, mindex, -offset + vmask.rdivide - 1);
-            shr(1, temp, temp, uint16_t(ilog2(vmask.rdivide)));
+        if (vmask.rshift) {
+            add(1 | sat, temp, mindex, -offset + rdivide - 1);
+            shr(1, temp, temp, uint16_t(vmask.rshift));
             mindex = temp;
             offset = 0;
         }
@@ -169,7 +169,7 @@ void BLASKernelGenerator<hw>::loadMask(MaskAssignment assignment, Subregister in
             mulConstant(1, temp, mindex, vmask.bitRep);
             mindex = temp;
         }
-        uint16_t tshift = vmask.bitRep * (rsizeScaled + div_up(assignment.offset + offset, vmask.rdivide));
+        uint16_t tshift = vmask.bitRep * (rsizeScaled + div_up(assignment.offset + offset, rdivide));
         add(1 | sat, temp, -mindex, tshift);
         if (tshift >= 32)
             min_(1, temp, temp, vmask.bitRep * rsizeScaled);    // Ensure shift count doesn't overflow.
diff --git a/src/gpu/intel/jit/gemm/generator/pieces/register_block.hpp b/src/gpu/intel/jit/gemm/generator/pieces/register_block.hpp
index b1e835a0663..9e811646c91 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/register_block.hpp
+++ b/src/gpu/intel/jit/gemm/generator/pieces/register_block.hpp
@@ -34,7 +34,7 @@ struct MaskInfo {
         struct {
             uint8_t isFixed : 1;  // = false (variable mask)
             uint8_t reverse : 1;  // True to reverse mask.
-            uint8_t rdivide : 6;  // Amount by which to divide index before forming mask. Fractions are rounded up.
+            uint8_t rshift : 6;   // Power of 2 by which to divide index before forming mask. Fractions are rounded up.
                                   //   Note maskRep * bitRep * (rsize >> rshift) = # mask bits.
             uint8_t rsize;        // Maximum remainder value. (e.g. 16 if we need the last 4 bits of the index).
             uint8_t maskRep;      // # of repetitions of mask pattern.
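
Note on the loadMask() hunks above: the patch can drop the runtime is_zero_or_pow2()/stub() check because rshift now stores the log2 of the old rdivide divisor, and the emitted add/shr pair computes a rounded-up division by that power of two. A minimal standalone sketch of the identity being relied on (ceilDivPow2 is an illustrative name, not a function from the patch):

#include <cassert>
#include <cstdint>

// Rounded-up division by a power of two, as loadMask() now emits it with an
// add+shr instruction pair: ceil(i / 2^s) == (i + 2^s - 1) >> s, s = vmask.rshift.
static uint32_t ceilDivPow2(uint32_t i, unsigned s) {
    return (i + (uint32_t(1) << s) - 1) >> s;
}

int main() {
    // Verify against plain rounded-up division for small index/shift values.
    for (uint32_t i = 0; i < 256; i++)
        for (unsigned s = 0; s < 8; s++)
            assert(ceilDivPow2(i, s) == (i + (1u << s) - 1) / (1u << s));
    return 0;
}

Storing the shift instead of the divisor also moves the ilog2() call to block setup time (getBlockInfo), so loadMask() no longer recomputes it per mask load.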
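And a small sketch of the mask-length bookkeeping documented in the register_block.hpp comment (maskRep * bitRep * (rsize >> rshift) = # mask bits), mirroring the rsizeScaled/maskLen/fullMask/rep1Mask computation in loadMask(). The struct and field values here are hypothetical, for illustration only; they are not the real MaskInfo:

#include <cstdint>
#include <cstdio>

// Illustrative mirror of MaskInfo's variable-mask fields (without the bitfields).
struct VarMaskSketch {
    uint8_t rshift;   // log2 of the old rdivide divisor
    uint8_t rsize;    // maximum remainder value
    uint8_t maskRep;  // # of repetitions of mask pattern
    uint8_t bitRep;   // # of times each mask bit is repeated
};

int main() {
    VarMaskSketch vmask{1, 8, 2, 2};  // hypothetical: divide index by 2 before masking
    uint32_t rsizeScaled = vmask.rsize >> vmask.rshift;             // 8 >> 1 = 4
    uint32_t maskLen = vmask.bitRep * vmask.maskRep * rsizeScaled;  // 2*2*4 = 16 mask bits
    uint32_t fullMask = uint32_t((uint64_t(1) << maskLen) - 1);     // 0xffff
    uint32_t rep1Mask = uint32_t((uint64_t(1) << (vmask.bitRep * rsizeScaled)) - 1);  // 0xff
    printf("maskLen=%u fullMask=%#x rep1Mask=%#x\n", maskLen, fullMask, rep1Mask);
    return 0;
}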