Skip to content

Commit

Permalink
cpu : aarch64 : reorder : reenabled bf16 jit uni reorders
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-fuj authored and Radu2k committed Dec 2, 2024
1 parent e72f6bd commit db5e699
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#===============================================================================

build
build*
external
.vs
.vscode
Expand Down
34 changes: 20 additions & 14 deletions src/cpu/aarch64/jit_uni_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,20 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
static bool applicable(const prb_t &p) {
using namespace data_type;

bool bf16_ok
= (mayiuse_bf16() && (p.itype == bf16) && (p.otype == bf16)
&& !interim_f32_needed(p, false) && p.beta == 0.f)
|| (p.itype != bf16 && p.otype != bf16)
|| (p.itype == f32 && p.otype == bf16 && mayiuse_bf16()
&& p.beta == 0.f);

bool ok = true && p.ndims > 0
&& utils::one_of(p.itype, f32, s32, data_type::s8, u8)
&& utils::one_of(p.itype, f32, bf16, s32, data_type::s8, u8)
&& utils::one_of(p.otype, f32, bf16, s32, data_type::s8, u8)
&& utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
&& utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
&& simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p)
&& IMPLICATION(
p.otype == bf16, p.itype == f32 && mayiuse_bf16());
&& bf16_ok;

return ok;
}
Expand Down Expand Up @@ -702,7 +708,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
const int load_tail_step
= !can_load_xmm && can_store_xmm ? ur_step : load_step;

const bool interim_f32 = interim_f32_needed();
const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_);

const bool need_saturation
= (utils::one_of(prb_.otype, u8, data_type::s8, s32)
Expand Down Expand Up @@ -1285,17 +1291,17 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
}
}

bool interim_f32_needed() {
static bool interim_f32_needed(const prb_t &prb, bool compensation_needed) {
using namespace data_type;

return utils::one_of(f32, prb_.itype, prb_.otype)
|| prb_.src_scale_type != scale_type_t::NONE
|| prb_.dst_scale_type != scale_type_t::NONE || prb_.beta != 0.f
|| ((prb_.req_src_zp || prb_.req_dst_zp)
? !(prb_.itype == s32 && prb_.otype == s32)
bool ret = utils::one_of(f32, prb.itype, prb.otype)
|| prb.src_scale_type != scale_type_t::NONE
|| prb.dst_scale_type != scale_type_t::NONE || prb.beta != 0.f
|| ((prb.req_src_zp || prb.req_dst_zp)
? !(prb.itype == s32 && prb.otype == s32)
: false)
|| (prb_.itype != f32 && compensation_needed_)
|| prb_.scale_adjust != 1.f;
|| (prb.itype != f32 && compensation_needed)
|| prb.scale_adjust != 1.f;
return ret;
}

void process_unroll_generic(
Expand All @@ -1313,7 +1319,7 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {

int curr = 0; // will switch between 0 and 1

const bool interim_f32 = interim_f32_needed();
const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_);

if (prb_.req_src_zp) {
add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0);
Expand Down

0 comments on commit db5e699

Please sign in to comment.