diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp index 9239ace9ae1..23832bd385d 100644 --- a/src/cpu/aarch64/jit_uni_reorder.cpp +++ b/src/cpu/aarch64/jit_uni_reorder.cpp @@ -166,6 +166,8 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { && !interim_f32_needed(p, false) && p.beta == 0.f) || (p.itype != bf16 && p.otype != bf16) || (p.itype == f32 && p.otype == bf16 && mayiuse_bf16() + && p.beta == 0.f) + || (p.itype == bf16 && p.otype == f32 && mayiuse_bf16() && p.beta == 0.f); bool ok = true && p.ndims > 0 @@ -279,6 +281,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { /* do nothing */ break; case s32: cvt_z_s32_f32(startIdx, regNum); break; + case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; case data_type::s8: cvt_z_s8_s32(startIdx, regNum); cvt_z_s32_f32(startIdx, regNum); @@ -308,6 +311,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { cvt_z_s32_s8(startIdx, regNum); if (idt == u8) cvt_z_u8_s8(startIdx, regNum); break; + case data_type::bf16: + if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); + break; case u8: if (idt == f32) cvt_z_f32_s32(startIdx, regNum); if (utils::one_of(idt, f32, s32)) @@ -620,6 +626,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { /* do nothing */ break; case s32: cvt_v_s32_f32(startIdx, regNum); break; + case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; case data_type::s8: cvt_v_s8_s32(startIdx, regNum); cvt_v_s32_f32(startIdx, regNum); @@ -635,6 +642,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { auto cvt2odt = [=](const int startIdx, const int regNum, data_type_t odt, data_type_t idt) { switch (odt) { + case f32: + if (idt == bf16) cvt_v_bf16_fp32(startIdx, regNum); + break; case s32: if (idt == f32) cvt_v_f32_s32(startIdx, regNum); @@ -1691,6 +1701,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { UNROLL_INST2(bfcvtn, VReg4H(i), VReg4S(i)); } + void cvt_v_bf16_fp32(const size_t startIdx, const size_t regNum) { + UNROLL_INST2(shll, VReg4S(i), VReg4H(i), 16); + } + void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) { cvt_z_b_s(startIdx, regNum); UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); diff --git a/tests/benchdnn/inputs/reorder/test_reorder_bfloat16 b/tests/benchdnn/inputs/reorder/test_reorder_bfloat16 index 87d47073893..4563b1a1578 100644 --- a/tests/benchdnn/inputs/reorder/test_reorder_bfloat16 +++ b/tests/benchdnn/inputs/reorder/test_reorder_bfloat16 @@ -1,6 +1,6 @@ -# f32, s8, u8 <--> bf16 +# f32, bf16, s8, u8 <--> bf16 --reset ---sdt=f32,s8,u8,f8_e5m2,f8_e4m3 --ddt=bf16 +--sdt=f32,bf16,s8,u8,f8_e5m2,f8_e4m3 --ddt=bf16 --stag=abx --dtag=aBx16b 2x64x14x14 2x56x14x14 --dtag=gOIhw16i16o 2x64x64x3x3 2x56x56x3x3