Skip to content

Commit

Permalink
src: cpu: aarch64: Enable jit bf16 -> f32 reorder (#2206)
Browse files Browse the repository at this point in the history
  • Loading branch information
aditew01 authored and mgouicem committed Dec 4, 2024
1 parent d13c966 commit 917dd13
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
14 changes: 14 additions & 0 deletions src/cpu/aarch64/jit_uni_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
&& !interim_f32_needed(p, false) && p.beta == 0.f)
|| (p.itype != bf16 && p.otype != bf16)
|| (p.itype == f32 && p.otype == bf16 && mayiuse_bf16()
&& p.beta == 0.f)
|| (p.itype == bf16 && p.otype == f32 && mayiuse_bf16()
&& p.beta == 0.f);

bool ok = true && p.ndims > 0
Expand Down Expand Up @@ -279,6 +281,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
/* do nothing */
break;
case s32: cvt_z_s32_f32(startIdx, regNum); break;
case bf16: cvt_v_bf16_fp32(startIdx, regNum); break;
case data_type::s8:
cvt_z_s8_s32(startIdx, regNum);
cvt_z_s32_f32(startIdx, regNum);
Expand Down Expand Up @@ -308,6 +311,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
cvt_z_s32_s8(startIdx, regNum);
if (idt == u8) cvt_z_u8_s8(startIdx, regNum);
break;
case data_type::bf16:
if (idt == f32) cvt_v_f32_bf16(startIdx, regNum);
break;
case u8:
if (idt == f32) cvt_z_f32_s32(startIdx, regNum);
if (utils::one_of(idt, f32, s32))
Expand Down Expand Up @@ -620,6 +626,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
/* do nothing */
break;
case s32: cvt_v_s32_f32(startIdx, regNum); break;
case bf16: cvt_v_bf16_fp32(startIdx, regNum); break;
case data_type::s8:
cvt_v_s8_s32(startIdx, regNum);
cvt_v_s32_f32(startIdx, regNum);
Expand All @@ -635,6 +642,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
auto cvt2odt = [=](const int startIdx, const int regNum,
data_type_t odt, data_type_t idt) {
switch (odt) {
case f32:
if (idt == bf16) cvt_v_bf16_fp32(startIdx, regNum);
break;
case s32:
if (idt == f32)
cvt_v_f32_s32(startIdx, regNum);
Expand Down Expand Up @@ -1691,6 +1701,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
UNROLL_INST2(bfcvtn, VReg4H(i), VReg4S(i));
}

void cvt_v_bf16_fp32(const size_t startIdx, const size_t regNum) {
UNROLL_INST2(shll, VReg4S(i), VReg4H(i), 16);
}

void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) {
cvt_z_b_s(startIdx, regNum);
UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/inputs/reorder/test_reorder_bfloat16
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# f32, s8, u8 <--> bf16
# f32, bf16, s8, u8 <--> bf16
--reset
--sdt=f32,s8,u8,f8_e5m2,f8_e4m3 --ddt=bf16
--sdt=f32,bf16,s8,u8,f8_e5m2,f8_e4m3 --ddt=bf16
--stag=abx
--dtag=aBx16b 2x64x14x14 2x56x14x14
--dtag=gOIhw16i16o 2x64x64x3x3 2x56x56x3x3
Expand Down

0 comments on commit 917dd13

Please sign in to comment.