Skip to content

Commit

Permalink
cpu: x64: simplify alpha and beta parameters in brgconv postops kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
kwiersch authored and tprimak committed Jan 9, 2023
1 parent 0a8116b commit 1b13037
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
bcfg->LDD = (is_init && jcp.use_buffer) ? jcp.LDC : jcp.LDD;
bcfg->dt_c = (!is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // inp
bcfg->dt_d = (is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // out
bcfg->alpha = is_init ? 0 : 1;
bcfg->alpha = !is_init && IMPLICATION(jcp.with_sum, jcp.use_buffer);
bcfg->beta = is_init ? 0 : 1;
CHECK(safe_ptr_assign(kernels_po_[ker_idx],
new jit_brgemm_kernel_post_ops<isa>(jcp, *bcfg, *_pd->attr())));
Expand Down
68 changes: 29 additions & 39 deletions src/cpu/x64/jit_brgemm_post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
brg.attr->post_ops_,
memory_desc_wrapper(brg.dst_md))) {

if ((jcp.with_sum && brg.beta != 0)
|| ((jcp.with_binary || jcp.with_eltwise) && brg.alpha != 0)) {
if (brg.beta != 0) {
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = true;
static constexpr bool use_exact_tail_scalar_bcast = false;
Expand All @@ -355,7 +354,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
k_tail_mask, use_exact_tail_scalar_bcast};
const binary_injector::static_params_t bsp {this->param1, rhs_sp};

const bool save_state = (brg.alpha != 0) && jcp.with_eltwise;
const bool save_state = jcp.with_eltwise;
const auto &reserved_eltwise_gpr = reg_reserved_eltwise;
const auto reserved_eltwise_maskr = Xbyak::Opmask(1);

Expand Down Expand Up @@ -579,7 +578,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
};

if (jcp.with_sum && brg.beta != 0) {
if (jcp.with_sum) {
postops_injector_->set_lambda_injector(
primitive_kind::sum, sum_injector);
}
Expand Down Expand Up @@ -607,7 +606,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
void apply_comp(int m_block, int n_block, int tail = 0) {
auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;

if (brg.alpha != 0 && brg.zp_type_a != brgemm_broadcast_t::none) {
if (brg.zp_type_a != brgemm_broadcast_t::none) {
auto vmm_zp_a_val = vmm_tmp(1);
mov(reg_zp_a_val, ptr[rsp + reg_zp_a_val_offs_]);
vpbroadcastd(vmm_zp_a_val, reg_zp_a_val.cvt32());
Expand All @@ -629,7 +628,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
}

if (brg.alpha != 0 && brg.req_s8s8_compensation) {
if (brg.req_s8s8_compensation) {
mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
for (int n = 0; n < n_block; n++) {
auto vmm_comp = vmm_tmp(0);
Expand Down Expand Up @@ -663,30 +662,20 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
void apply_post_ops(int m_block, int n_block, int tail = 0) {
const auto vector = [=](int m, int n) { return Vmm(m * n_block + n); };
auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;
const auto &p = attr.post_ops_;
const int sum_idx = p.find(primitive_kind::sum);
const auto req_comp = brg.is_int8 && brg.alpha != 0
const auto req_comp = brg.is_int8 && brg.beta != 0
&& (brg.req_s8s8_compensation
|| brg.zp_type_a != brgemm_broadcast_t::none);

// brg.alpha == 0 means no read from input, no bias, no eltwise - just
// initialize registers by zero at the beginning of kernel
// brg.beta == 0 means no sum - just registers write to output
// brg.alpha == 0 means initialize registers, 1 means read from input
// brg.beta == 0 means skip postwork, 1 means do postwork
// req_comp == true -> convert accumulated values to f32 after applying
// compensation to avoid the loss of accuracy when converting s32 to f32
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
if (brg.alpha == 0) {
if (sum_idx != -1 && brg.beta != 0) {
// if sum then have to init vmm each time
uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
}
} else if (!IMPLICATION(jcp.with_sum, jcp.use_buffer)) {
if (sum_idx != -1 && brg.beta != 0) {
// if sum without buffer then have to init vmm each time
uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
}
} else {
if (brg.alpha == 0 && brg.beta != 0) {
// if postwork then have to init vmm each time
uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
} else if (brg.alpha != 0) {
auto inp_addr = ptr[aux_reg_in
+ inp_typesize_ * (m * brg.LDC + n * brg.ld_block)];
cvt2ps(inp_dt_, vector(m, n), inp_addr, tail, false, k_mask,
Expand All @@ -696,7 +685,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

if (req_comp) maybe_apply_comp(m_block, n_block, tail);

if (brg.alpha != 0 && jcp.with_bias) {
if (brg.beta != 0 && jcp.with_bias) {
for (int n = 0; n < n_block; n++) {
auto vmm_bias = vmm_tmp(0);
auto bias_addr = ptr[aux_reg_bias
Expand All @@ -708,7 +697,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
}

if (brg.alpha != 0) {
if (brg.beta != 0) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
const auto addr = ptr[aux_reg_scales
Expand All @@ -727,7 +716,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

if (postops_injector_) inject_attr_postops(m_block, n_block, tail);

if (brg.alpha != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
if (brg.beta != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
mov(aux_reg_zp_c_values, ptr[rsp + aux_reg_zp_c_values_offs_]);
auto vmm_zp_c = vmm_tmp(0);
if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
Expand Down Expand Up @@ -808,8 +797,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

void loop_by_N(int m_block, int nb2, int nb2_tail, int nb_tail) {

if (brg.alpha) {
mov(aux_reg_in, reg_in);
if (brg.alpha) { mov(aux_reg_in, reg_in); }
if (brg.beta != 0) {
if (jcp.with_bias) mov(aux_reg_bias, reg_bias);
if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(aux_reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]);
Expand All @@ -835,7 +824,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
add(aux_reg_out, out_typesize_ * oc_l_offset);
if (brg.alpha != 0) {
add(aux_reg_in, inp_typesize_ * oc_l_offset);

}
if (brg.beta != 0) {
if (jcp.with_bias)
add(aux_reg_bias, bia_typesize_ * oc_l_offset);
if (brg.zp_type_c != brgemm_broadcast_t::none) {
Expand Down Expand Up @@ -866,6 +856,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
add(aux_reg_out, out_typesize_ * oc_l_offset);
if (brg.alpha != 0) {
add(aux_reg_in, inp_typesize_ * oc_l_offset);
}
if (brg.beta != 0) {
if (jcp.with_bias)
add(aux_reg_bias, bia_typesize_ * oc_l_offset);
if (brg.zp_type_c != brgemm_broadcast_t::none) {
Expand All @@ -892,8 +884,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
if (nb_tail > 0) {
apply_post_ops(m_block, 1, nb_tail);

if (brg.alpha != 0) {
add(aux_reg_in, inp_typesize_ * (nb_tail));
if (brg.alpha != 0) { add(aux_reg_in, inp_typesize_ * (nb_tail)); }
if (brg.beta != 0) {
if (jcp.with_bias) add(aux_reg_bias, bia_typesize_ * (nb_tail));
if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(aux_reg_zp_c_values,
Expand Down Expand Up @@ -948,8 +940,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
kmovq(k_tail_mask, reg_mask);
}

if (brg.alpha != 0) {
mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]);
if (brg.alpha != 0) { mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]); }
if (brg.beta != 0) {
mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]);
mov(reg_apply_comp, ptr[param1 + GET_OFF(apply_comp)]);
mov(ptr[rsp + reg_apply_comp_offs_], reg_apply_comp);
Expand All @@ -973,10 +965,9 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
mov(reg_out, ptr[param1 + GET_OFF(ptr_out)]);

// brg.alpha == 0 means no read from input, no bias, no eltwise - just
// initialize registers by zero
// brg.beta == 0 means no sum - just registers write to output
if (brg.alpha == 0) {
// brg.alpha == 0 means initialize registers, 1 means read from input
// brg.beta == 0 means skip postwork, 1 means do postwork
if (brg.alpha == 0 && brg.beta == 0) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
auto vmm = Vmm(m * n_block + n);
Expand All @@ -997,8 +988,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

postamble();

if (brg.alpha != 0 && jcp.with_eltwise)
postops_injector_->prepare_table();
if (postops_injector_) postops_injector_->prepare_table();
}
};

Expand Down

0 comments on commit 1b13037

Please sign in to comment.