Skip to content

Commit

Permalink
x64: fix scales and bias order in brgemm conv post-ops kernel
Browse files Browse the repository at this point in the history
Test coverage extended.
  • Loading branch information
akharito committed Jan 17, 2023
1 parent 8bb651c commit 27845b8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
24 changes: 12 additions & 12 deletions src/cpu/x64/jit_brgemm_post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,18 +685,6 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

if (req_comp) maybe_apply_comp(m_block, n_block, tail);

if (brg.beta != 0 && jcp.with_bias) {
for (int n = 0; n < n_block; n++) {
auto vmm_bias = vmm_tmp(0);
auto bias_addr = ptr[aux_reg_bias
+ bia_typesize_ * (n * brg.ld_block)];
cvt2ps(bia_dt_, vmm_bias, bias_addr, tail, false, k_mask);
for (int m = 0; m < m_block; m++) {
vaddps(vector(m, n), vmm_bias);
}
}
}

if (brg.beta != 0) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
Expand All @@ -714,6 +702,18 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}
}

if (brg.beta != 0 && jcp.with_bias) {
for (int n = 0; n < n_block; n++) {
auto vmm_bias = vmm_tmp(0);
auto bias_addr = ptr[aux_reg_bias
+ bia_typesize_ * (n * brg.ld_block)];
cvt2ps(bia_dt_, vmm_bias, bias_addr, tail, false, k_mask);
for (int m = 0; m < m_block; m++) {
vaddps(vector(m, n), vmm_bias);
}
}
}

if (postops_injector_) inject_attr_postops(m_block, n_block, tail);

if (brg.beta != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
Expand Down
4 changes: 3 additions & 1 deletion tests/benchdnn/inputs/conv/harness_conv_attrs_int8
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
--mb=2
--skip-impl=ref,x64:gemm # ! test jit version only
--dir=FWD_B
--attr-scales=src:common:0.25*+wei:per_oc:0.5*+dst:common:2.25* --attr-post-ops=sum:1.5:2+relu
--attr-scales=src:common:0.25*+wei:per_oc:0.5*,src:common:0.25*+wei:per_oc:0.5*+dst:common:2.25*
--attr-post-ops=sum:1.5:2+relu
--cfg=s8s8f32,s8s8u8,u8s8f32,u8s8u8 --batch=shapes_tails
--cfg=s8s8u8,u8s8u8 --batch=shapes_basic

# i8 conv + f32 leaky relu
--reset --dir=FWD_B --mb=2
Expand Down

0 comments on commit 27845b8

Please sign in to comment.