Skip to content

Commit

Permalink
gpu: jit: gemm: x32->f16 conversion fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad authored and karturov committed Oct 7, 2022
1 parent c8adc17 commit 3693fbf
Showing 1 changed file with 190 additions and 178 deletions.
368 changes: 190 additions & 178 deletions src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,183 @@ static inline int eusPerSubslice(HW hw) {
}
}

// Report whether operations on elements of type dt may cover two registers
// at once: the strategy must have dualGRF set and a single register must
// hold fewer than 32 elements of dt.
static inline bool canDualGRF(
        HW hw, DataType dt, const CommonStrategy &strategy) {
    if (!strategy.dualGRF) return false;
    return elementsPerGRF(hw, dt) < 32;
}

// Apply a binary operation register by register across two ranges.
//
// r1 and r2 are walked in lockstep, each register retyped to dt. Two
// registers are handed to f per call when canDualGRF() allows it and both
// ranges are contiguous over that pair; otherwise a single register is used.
// f receives (element count, r1 register, r2 register). The number of
// registers visited is taken from r1.getLen() only.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
        const GRFMultirange &r2, const CommonStrategy &strategy, F f) {
    const int elemsPerReg = elementsPerGRF(hw, dt);
    const int maxRegs = canDualGRF(hw, dt, strategy) ? 2 : 1;
    const int total = r1.getLen();

    int idx = 0;
    while (idx < total) {
        int count = std::min<int>(total - idx, maxRegs);
        const bool packed
                = r1.contiguous(idx, count) && r2.contiguous(idx, count);
        if (!packed) count = 1;
        f(count * elemsPerReg, r1[idx].retype(dt), r2[idx].retype(dt));
        idx += count;
    }
}

// Apply a ternary operation register by register across three ranges.
//
// r1, r2 and r3 are walked in lockstep, each register retyped to dt. Two
// registers are handed to f per call when canDualGRF() allows it and all
// three ranges are contiguous over that pair; otherwise a single register is
// used. f receives (element count, r1 register, r2 register, r3 register).
// The number of registers visited is taken from r1.getLen() only.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
        const GRFMultirange &r2, const GRFMultirange &r3,
        const CommonStrategy &strategy, F f) {
    const int elemsPerReg = elementsPerGRF(hw, dt);
    const int maxRegs = canDualGRF(hw, dt, strategy) ? 2 : 1;
    const int total = r1.getLen();

    int idx = 0;
    while (idx < total) {
        int count = std::min<int>(total - idx, maxRegs);
        const bool packed = r1.contiguous(idx, count)
                && r2.contiguous(idx, count) && r3.contiguous(idx, count);
        if (!packed) count = 1;
        f(count * elemsPerReg, r1[idx].retype(dt), r2[idx].retype(dt),
                r3[idx].retype(dt));
        idx += count;
    }
}

// Apply a quaternary operation register by register across four ranges.
//
// r1 through r4 are walked in lockstep, each register retyped to dt. Two
// registers are handed to f per call when canDualGRF() allows it and all four
// ranges are contiguous over that pair; otherwise a single register is used.
// f receives (element count, r1 register, r2 register, r3 register,
// r4 register). The number of registers visited is taken from r1.getLen()
// only.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
        const GRFMultirange &r2, const GRFMultirange &r3,
        const GRFMultirange &r4, const CommonStrategy &strategy, F f) {
    const int elemsPerReg = elementsPerGRF(hw, dt);
    const int maxRegs = canDualGRF(hw, dt, strategy) ? 2 : 1;
    const int total = r1.getLen();

    int idx = 0;
    while (idx < total) {
        int count = std::min<int>(total - idx, maxRegs);
        const bool packed = r1.contiguous(idx, count)
                && r2.contiguous(idx, count) && r3.contiguous(idx, count)
                && r4.contiguous(idx, count);
        if (!packed) count = 1;
        f(count * elemsPerReg, r1[idx].retype(dt), r2[idx].retype(dt),
                r3[idx].retype(dt), r4[idx].retype(dt));
        idx += count;
    }
}

// Perform a unary register-wise operation on a register block.
//
// Consecutive entries of `layout` are merged into one byte run for as long as
// each block starts exactly where the current run ends (block.offsetBytes ==
// curOff + curBytes) and `regs` is contiguous over the registers the merged
// run would cover. Each run is then cut into chunks and f(ne, reg) is called
// once per chunk, where ne is the element count and reg is the region built
// by `.sub(elementOffset, dt)(1)` on the register holding the chunk's start.
//
// Chunk sizing:
//  - a chunk that starts in the middle of a register stops at the end of that
//    register;
//  - a chunk that starts on a register boundary covers at most one register,
//    or two when canDualGRF() permits;
//  - the byte count is rounded down to a power of two and the element count
//    is capped at 32.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &regs,
        const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
        F f) {
    int curReg = 0, curOff = 0, curBytes = 0;
    auto ebytes = getBytes(dt);
    const int grfBytes = GRF::bytes(hw);

    // Emit the pending run [curOff, curOff + curBytes) chunk by chunk.
    auto map1 = [&]() {
        // Round start and length down to whole elements (assumes ebytes is a
        // power of two, as the mask arithmetic requires).
        curOff &= -ebytes;
        curBytes &= -ebytes;
        while (curBytes) {
            // curOff is an offset into the whole register range (it is
            // shifted by log2Bytes below to select the register), so the
            // space left in the current register must be computed from the
            // offset *within* that register. Subtracting curOff unreduced
            // goes negative for any unaligned offset past the first register.
            const int grfOff = curOff & (grfBytes - 1);
            int maxBytes;
            if (grfOff)
                maxBytes = grfBytes - grfOff;
            else
                maxBytes = (canDualGRF(hw, dt, strategy) ? 2 : 1) * grfBytes;

            auto nbytes = rounddown_pow2(std::min(maxBytes, curBytes));
            auto ne = std::min<int>(32, nbytes / ebytes);
            nbytes = ne * ebytes;

            auto reg = regs[curOff >> GRF::log2Bytes(hw)].sub(
                    grfOff / ebytes, dt)(1);

            f(ne, reg);

            curBytes -= nbytes;
            curOff += nbytes;
        }
    };

    for (auto &block : layout) {
        // Last register the run would touch if this block were appended.
        int endReg
                = (curOff + curBytes + block.bytes - 1) >> GRF::log2Bytes(hw);
        if ((block.offsetBytes == curOff + curBytes)
                && regs.contiguous(curReg, endReg - curReg + 1))
            curBytes += block.bytes;
        else {
            // Block cannot extend the run: flush it and start a new one here.
            map1();
            curOff = block.offsetBytes;
            curReg = curOff >> GRF::log2Bytes(hw);
            curBytes = block.bytes;
        }
    }

    // Flush whatever run is still pending after the last block.
    map1();
}

// Binary map() with the element type supplied as template parameter T; the
// matching DataType is obtained from getDataType<T>().
template <typename T, typename F>
static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2,
        const CommonStrategy &strategy, F f) {
    const auto dt = getDataType<T>();
    map(hw, dt, r1, r2, strategy, f);
}

// Ternary map() with the element type supplied as template parameter T; the
// matching DataType is obtained from getDataType<T>().
template <typename T, typename F>
static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2,
        const GRFMultirange &r3, const CommonStrategy &strategy, F f) {
    const auto dt = getDataType<T>();
    map(hw, dt, r1, r2, r3, strategy, f);
}

// Layout-driven unary map() with the element type supplied as template
// parameter T; the matching DataType is obtained from getDataType<T>().
template <typename T, typename F>
static inline void map(HW hw, const GRFMultirange &regs,
        const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
        F f) {
    const auto dt = getDataType<T>();
    map(hw, dt, regs, layout, strategy, f);
}

// Forward to the map() overload selected by the remaining arguments, after
// converting the Type argument via T.ngen().
template <typename... Targs>
static inline void map(HW hw, Type T, Targs... args) {
    const auto dt = T.ngen();
    map(hw, dt, args...);
}

// Move subregister to another pipe by reinterpreting it with a different type.
//
// Type mapping applied to s:
//   bf, hf   -> uw
//   tf32, f  -> ud
//   df       -> ud   (only if sizeCanChange)
//   w, uw    -> hf
//   d, ud    -> f
//   q, uq    -> f    (only if sizeCanChange)
// Any other type, and df/q/uq when sizeCanChange is false, keeps its type.
// s is always reassigned from s.reinterpret(0, <resulting type>).
static inline void movePipes(Subregister &s, bool sizeCanChange = true) {
    const DataType from = s.getType();
    DataType to = from;

    if (from == DataType::bf || from == DataType::hf) {
        to = DataType::uw;
    } else if (from == DataType::tf32 || from == DataType::f) {
        to = DataType::ud;
    } else if (from == DataType::df) {
        if (sizeCanChange) to = DataType::ud;
    } else if (from == DataType::w || from == DataType::uw) {
        to = DataType::hf;
    } else if (from == DataType::d || from == DataType::ud) {
        to = DataType::f;
    } else if (from == DataType::q || from == DataType::uq) {
        if (sizeCanChange) to = DataType::f;
    }

    s = s.reinterpret(0, to);
}

// Move register region to integer pipe by changing its type in place.
//
// Type mapping applied to s:
//   bf, hf          -> uw
//   q, uq, tf32, f  -> ud
//   df              -> uq, then EmulationImplementation::makeDWPair(s, esize)
// Any other type is left untouched. esize is only used in the df case.
static inline void moveToIntPipe(int esize, RegData &s) {
    const auto type = s.getType();

    if (type == DataType::bf || type == DataType::hf) {
        s.setType(DataType::uw);
        return;
    }

    if (type == DataType::q || type == DataType::uq || type == DataType::tf32
            || type == DataType::f) {
        s.setType(DataType::ud);
        return;
    }

    if (type == DataType::df) {
        s.setType(DataType::uq);
        EmulationImplementation::makeDWPair(s, esize);
    }
}

void RegisterBlock::calcBytes(
Type T, const MatrixAddressingStrategy &astrategy) {
if (astrategy.newDP && astrategy.prefetch)
Expand Down Expand Up @@ -545,6 +722,15 @@ void gemm_kernel_generator_t<hw>::emov(const ngen::InstructionModifier &mod,
src0.setType(DataType::f);
}

if (hw >= HW::XeHP
&& one_of(src0.getType(), DataType::hf, DataType::f, DataType::bf)
&& src0.getType() == dst.getType()
&& ((src0.getHS() != dst.getHS())
|| (src0.getOffset() != dst.getOffset()))) {
moveToIntPipe(mod.getExecSize(), dst);
moveToIntPipe(mod.getExecSize(), src0);
}

if (hw < HW::XeHP && dst.getType() == DataType::f
&& src0.getType() == DataType::bf) {
dst.setType(DataType::ud);
Expand Down Expand Up @@ -967,183 +1153,6 @@ Subregister gemm_kernel_generator_t<hw>::copySubregister(
return copy;
}

// Report whether operations on elements of type dt may cover two registers
// at once: the strategy must have dualGRF set and a single register must
// hold fewer than 32 elements of dt.
static inline bool canDualGRF(
        HW hw, DataType dt, const CommonStrategy &strategy) {
    if (!strategy.dualGRF) return false;
    return elementsPerGRF(hw, dt) < 32;
}

// Apply a binary operation register by register across two ranges.
//
// r1 and r2 are walked in lockstep, each register retyped to dt. Two
// registers are handed to f per call when canDualGRF() allows it and both
// ranges are contiguous over that pair; otherwise a single register is used.
// f receives (element count, r1 register, r2 register). The number of
// registers visited is taken from r1.getLen() only.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
        const GRFMultirange &r2, const CommonStrategy &strategy, F f) {
    const int elemsPerReg = elementsPerGRF(hw, dt);
    const int maxRegs = canDualGRF(hw, dt, strategy) ? 2 : 1;
    const int total = r1.getLen();

    int idx = 0;
    while (idx < total) {
        int count = std::min<int>(total - idx, maxRegs);
        const bool packed
                = r1.contiguous(idx, count) && r2.contiguous(idx, count);
        if (!packed) count = 1;
        f(count * elemsPerReg, r1[idx].retype(dt), r2[idx].retype(dt));
        idx += count;
    }
}

// Apply a ternary operation register by register across three ranges.
//
// r1, r2 and r3 are walked in lockstep, each register retyped to dt. Two
// registers are handed to f per call when canDualGRF() allows it and all
// three ranges are contiguous over that pair; otherwise a single register is
// used. f receives (element count, r1 register, r2 register, r3 register).
// The number of registers visited is taken from r1.getLen() only.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
        const GRFMultirange &r2, const GRFMultirange &r3,
        const CommonStrategy &strategy, F f) {
    const int elemsPerReg = elementsPerGRF(hw, dt);
    const int maxRegs = canDualGRF(hw, dt, strategy) ? 2 : 1;
    const int total = r1.getLen();

    int idx = 0;
    while (idx < total) {
        int count = std::min<int>(total - idx, maxRegs);
        const bool packed = r1.contiguous(idx, count)
                && r2.contiguous(idx, count) && r3.contiguous(idx, count);
        if (!packed) count = 1;
        f(count * elemsPerReg, r1[idx].retype(dt), r2[idx].retype(dt),
                r3[idx].retype(dt));
        idx += count;
    }
}

// Apply a quaternary operation register by register across four ranges.
//
// r1 through r4 are walked in lockstep, each register retyped to dt. Two
// registers are handed to f per call when canDualGRF() allows it and all four
// ranges are contiguous over that pair; otherwise a single register is used.
// f receives (element count, r1 register, r2 register, r3 register,
// r4 register). The number of registers visited is taken from r1.getLen()
// only.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
        const GRFMultirange &r2, const GRFMultirange &r3,
        const GRFMultirange &r4, const CommonStrategy &strategy, F f) {
    const int elemsPerReg = elementsPerGRF(hw, dt);
    const int maxRegs = canDualGRF(hw, dt, strategy) ? 2 : 1;
    const int total = r1.getLen();

    int idx = 0;
    while (idx < total) {
        int count = std::min<int>(total - idx, maxRegs);
        const bool packed = r1.contiguous(idx, count)
                && r2.contiguous(idx, count) && r3.contiguous(idx, count)
                && r4.contiguous(idx, count);
        if (!packed) count = 1;
        f(count * elemsPerReg, r1[idx].retype(dt), r2[idx].retype(dt),
                r3[idx].retype(dt), r4[idx].retype(dt));
        idx += count;
    }
}

// Perform a unary register-wise operation on a register block.
//
// Consecutive entries of `layout` are merged into one byte run for as long as
// each block starts exactly where the current run ends (block.offsetBytes ==
// curOff + curBytes) and `regs` is contiguous over the registers the merged
// run would cover. Each run is then cut into chunks and f(ne, reg) is called
// once per chunk, where ne is the element count and reg is the region built
// by `.sub(elementOffset, dt)(1)` on the register holding the chunk's start.
//
// Chunk sizing:
//  - a chunk that starts in the middle of a register stops at the end of that
//    register;
//  - a chunk that starts on a register boundary covers at most one register,
//    or two when canDualGRF() permits;
//  - the byte count is rounded down to a power of two and the element count
//    is capped at 32.
template <typename F>
static inline void map(HW hw, DataType dt, const GRFMultirange &regs,
        const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
        F f) {
    int curReg = 0, curOff = 0, curBytes = 0;
    auto ebytes = getBytes(dt);
    const int grfBytes = GRF::bytes(hw);

    // Emit the pending run [curOff, curOff + curBytes) chunk by chunk.
    auto map1 = [&]() {
        // Round start and length down to whole elements (assumes ebytes is a
        // power of two, as the mask arithmetic requires).
        curOff &= -ebytes;
        curBytes &= -ebytes;
        while (curBytes) {
            // curOff is an offset into the whole register range (it is
            // shifted by log2Bytes below to select the register), so the
            // space left in the current register must be computed from the
            // offset *within* that register. Subtracting curOff unreduced
            // goes negative for any unaligned offset past the first register.
            const int grfOff = curOff & (grfBytes - 1);
            int maxBytes;
            if (grfOff)
                maxBytes = grfBytes - grfOff;
            else
                maxBytes = (canDualGRF(hw, dt, strategy) ? 2 : 1) * grfBytes;

            auto nbytes = rounddown_pow2(std::min(maxBytes, curBytes));
            auto ne = std::min<int>(32, nbytes / ebytes);
            nbytes = ne * ebytes;

            auto reg = regs[curOff >> GRF::log2Bytes(hw)].sub(
                    grfOff / ebytes, dt)(1);

            f(ne, reg);

            curBytes -= nbytes;
            curOff += nbytes;
        }
    };

    for (auto &block : layout) {
        // Last register the run would touch if this block were appended.
        int endReg
                = (curOff + curBytes + block.bytes - 1) >> GRF::log2Bytes(hw);
        if ((block.offsetBytes == curOff + curBytes)
                && regs.contiguous(curReg, endReg - curReg + 1))
            curBytes += block.bytes;
        else {
            // Block cannot extend the run: flush it and start a new one here.
            map1();
            curOff = block.offsetBytes;
            curReg = curOff >> GRF::log2Bytes(hw);
            curBytes = block.bytes;
        }
    }

    // Flush whatever run is still pending after the last block.
    map1();
}

// Binary map() with the element type supplied as template parameter T; the
// matching DataType is obtained from getDataType<T>().
template <typename T, typename F>
static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2,
        const CommonStrategy &strategy, F f) {
    const auto dt = getDataType<T>();
    map(hw, dt, r1, r2, strategy, f);
}

// Ternary map() with the element type supplied as template parameter T; the
// matching DataType is obtained from getDataType<T>().
template <typename T, typename F>
static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2,
        const GRFMultirange &r3, const CommonStrategy &strategy, F f) {
    const auto dt = getDataType<T>();
    map(hw, dt, r1, r2, r3, strategy, f);
}

// Layout-driven unary map() with the element type supplied as template
// parameter T; the matching DataType is obtained from getDataType<T>().
template <typename T, typename F>
static inline void map(HW hw, const GRFMultirange &regs,
        const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
        F f) {
    const auto dt = getDataType<T>();
    map(hw, dt, regs, layout, strategy, f);
}

// Forward to the map() overload selected by the remaining arguments, after
// converting the Type argument via T.ngen().
template <typename... Targs>
static inline void map(HW hw, Type T, Targs... args) {
    const auto dt = T.ngen();
    map(hw, dt, args...);
}

// Move subregister to another pipe by reinterpreting it with a different type.
//
// Type mapping applied to s:
//   bf, hf   -> uw
//   tf32, f  -> ud
//   df       -> ud   (only if sizeCanChange)
//   w, uw    -> hf
//   d, ud    -> f
//   q, uq    -> f    (only if sizeCanChange)
// Any other type, and df/q/uq when sizeCanChange is false, keeps its type.
// s is always reassigned from s.reinterpret(0, <resulting type>).
static inline void movePipes(Subregister &s, bool sizeCanChange = true) {
    const DataType from = s.getType();
    DataType to = from;

    if (from == DataType::bf || from == DataType::hf) {
        to = DataType::uw;
    } else if (from == DataType::tf32 || from == DataType::f) {
        to = DataType::ud;
    } else if (from == DataType::df) {
        if (sizeCanChange) to = DataType::ud;
    } else if (from == DataType::w || from == DataType::uw) {
        to = DataType::hf;
    } else if (from == DataType::d || from == DataType::ud) {
        to = DataType::f;
    } else if (from == DataType::q || from == DataType::uq) {
        if (sizeCanChange) to = DataType::f;
    }

    s = s.reinterpret(0, to);
}

// Move register region to integer pipe by changing its type in place.
//
// Type mapping applied to s:
//   bf, hf          -> uw
//   q, uq, tf32, f  -> ud
//   df              -> uq, then EmulationImplementation::makeDWPair(s, esize)
// Any other type is left untouched. esize is only used in the df case.
static inline void moveToIntPipe(int esize, RegData &s) {
    const auto type = s.getType();

    if (type == DataType::bf || type == DataType::hf) {
        s.setType(DataType::uw);
        return;
    }

    if (type == DataType::q || type == DataType::uq || type == DataType::tf32
            || type == DataType::f) {
        s.setType(DataType::ud);
        return;
    }

    if (type == DataType::df) {
        s.setType(DataType::uq);
        EmulationImplementation::makeDWPair(s, esize);
    }
}

// Set a matrix to zero.
template <HW hw>
void gemm_kernel_generator_t<hw>::zeroMatrix(
Expand Down Expand Up @@ -20132,7 +20141,10 @@ bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td,
// Check if separate conversions are needed due to size changes.
auto sconvertCP = (Ts.size() / Td.size());
bool sconvert = (Td.size() == 1 && Ts.size() > 1
&& dcrosspack != sconvertCP);
&& dcrosspack != sconvertCP)
|| (Td.size() == 2 && Td.isFP() && !Ts.isFP()
&& dcrosspack != sconvertCP
&& hw > HW::Gen9);
if (sconvert && preserveSrc) stub();
auto sregConverted = sconvert
? sreg.reinterpret(0, Td.real().ngen())(sconvertCP)
Expand Down

0 comments on commit 3693fbf

Please sign in to comment.