Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

x64: brdgmm conv: enable zps per group (backport) #2292

Draft
wants to merge 2 commits into
base: rls-v3.6
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/primitive_attr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ struct zero_points_t : public c_compatible {

// arg-specific checks
// True iff the zero-point for `arg` is a single common (per-tensor) value:
// a mask of 0 means no dimension is distinguished.
bool common(int arg) const { return get_mask(arg) == 0; }
// True iff the zero-point for `arg` broadcasts along logical dim 1: a mask
// of 2 (bit 1 set) selects the second dimension (e.g. channels/groups for
// conv src -- presumably the "per group" case this backport enables).
bool per_dim_1(int arg) const { return get_mask(arg) == 2; }
// NOTE(review): `defined` delegates to has_default_values(arg), i.e. it is
// true only when the zero-point was never set. Confirm this is intentional
// (runtime zero-points are never "defined" at pd-creation time) rather than
// a missing negation.
bool defined(int arg) const { return has_default_values(arg); }
bool has_default_values(int arg) const {
return is_set(arg) == false && has_default_data_type(arg);
Expand Down
26 changes: 19 additions & 7 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
brgemm_p.a_zp_values = post_ops_data.a_zp_values;
brgemm_p.c_zp_values = post_ops_data.c_zp_values;
brgemm_p.ptr_dst_scales = post_ops_data.dst_scales;
if (dynamic_values) {
Expand Down Expand Up @@ -457,19 +458,30 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg,
auto zero_points = attr->zero_points_;

// common zero point type is supported for now
if (!zero_points.common(mem_arg)) return status::unimplemented;
const bool is_per_dim_1_bcast = zero_points.per_dim_1(mem_arg);
const bool is_common_bcast = zero_points.common(mem_arg);
if (!is_common_bcast && !is_per_dim_1_bcast)
return status::unimplemented;

const bool skip_zero_point
= mem_arg == DNNL_ARG_WEIGHTS && brg->skip_zp_b_compensation;
zp_type = zero_points.has_default_values(mem_arg) || skip_zero_point
? brgemm_broadcast_t::none
: brgemm_broadcast_t::per_tensor;

zp_type = brgemm_broadcast_t::none;
const bool is_any_bcast
= !(zero_points.has_default_values(mem_arg) || skip_zero_point);
if (is_any_bcast) {
if (is_common_bcast)
zp_type = brgemm_broadcast_t::per_tensor;
else if (is_per_dim_1_bcast)
zp_type = brgemm_broadcast_t::per_n;
}

return status::success;
};

init_zp_type(brg->zp_type_a, DNNL_ARG_SRC);
init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS);
init_zp_type(brg->zp_type_c, DNNL_ARG_DST);
CHECK(init_zp_type(brg->zp_type_a, DNNL_ARG_SRC));
CHECK(init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS));
CHECK(init_zp_type(brg->zp_type_c, DNNL_ARG_DST));

// Post-ops may use vector registers so brgemm/brdgmm blocking may need to
// be updated
Expand Down
8 changes: 6 additions & 2 deletions src/cpu/x64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ struct brgemm_kernel_params_t {

const void *a_zp_compensations = nullptr;
const void *b_zp_compensations = nullptr;
const void *a_zp_values = nullptr;
const void *c_zp_values = nullptr;
size_t skip_accm = 0;
int32_t zp_a_val = 1;
Expand Down Expand Up @@ -582,7 +583,8 @@ struct brgemm_post_ops_data_t {
const void *b_zp_compensations = nullptr,
const void *c_zp_values = nullptr, bool skip_accumulation = false,
int32_t zp_a_val = 1, bool do_only_comp = false,
bool do_only_zp_a_val = false, const float *dst_scales = nullptr)
bool do_only_zp_a_val = false, const float *dst_scales = nullptr,
const void *a_zp_values = nullptr)
: bias(bias)
, scales(scales)
, binary_post_ops_rhs(binary_post_ops_rhs)
Expand All @@ -597,7 +599,8 @@ struct brgemm_post_ops_data_t {
, zp_a_val {zp_a_val}
, do_only_comp {do_only_comp}
, do_only_zp_a_val {do_only_zp_a_val}
, dst_scales(dst_scales) {}
, dst_scales(dst_scales)
, a_zp_values(a_zp_values) {}

const void *bias = nullptr;
const float *scales = nullptr;
Expand All @@ -614,6 +617,7 @@ struct brgemm_post_ops_data_t {
const bool do_only_comp = false;
const bool do_only_zp_a_val = false;
const float *dst_scales = nullptr;
const void *a_zp_values = nullptr;
};

} // namespace x64
Expand Down
120 changes: 90 additions & 30 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jit_brdgmm_kernel_base_t<Wmm>::jit_brdgmm_kernel_base_t(
, max_vmms_(isa_num_vregs(brg.isa_impl))
, compute_dst_zp_(brg.zp_type_c != brgemm_broadcast_t::none)
, compute_src_zp_(brg.zp_type_a != brgemm_broadcast_t::none)
, is_src_zp_bcast_(brg.zp_type_a == brgemm_broadcast_t::per_tensor)
, compute_compensation_(compute_src_zp_ || brg.req_s8s8_compensation)
, has_vpad_(brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0)
, has_bpad_(brg.brgattr.max_top_bpad > 0 || brg.brgattr.max_bottom_bpad > 0)
Expand Down Expand Up @@ -147,7 +148,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::read_params() {
}

if (compute_src_zp_) {
mov(reg_tmp, ptr[param1 + GET_OFF(zp_a_val)]);
mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_values)]);
mov(ptr[rsp + src_zp_value_], reg_tmp);

mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_compensations)]);
Expand Down Expand Up @@ -609,6 +610,17 @@ void jit_brdgmm_kernel_base_t<Wmm>::maybe_transpose_interleaved_vnni_to_plain(
}
}

template <typename Wmm>
// Loads the base pointer of the source zero-point value(s) that read_params()
// spilled to the stack slot `src_zp_value_`, and positions it for the current
// N-block:
// - per-tensor (common) zp (`is_src_zp_bcast_`): a single int32 scalar is
//   consumed; the ptr_b form keeps broadcast-style addressing, and the lea is
//   effectively a no-op for the address itself.
// - per-N zp: advance into the int32 zp array by the current N offset held
//   in reg_aux_N.
void jit_brdgmm_kernel_base_t<Wmm>::load_src_zp() {
mov(reg_src_zero_point, ptr[rsp + src_zp_value_]);
lea(reg_src_zero_point,
is_src_zp_bcast_
? ptr_b[reg_src_zero_point]
: ptr[reg_src_zero_point + reg_aux_N * sizeof(int32_t)]);
// Pre-AVX512 ISAs cannot fold a scalar broadcast into the consuming
// instruction's memory operand, so materialize the broadcast value into
// vmm_bcast() up front.
// NOTE(review): assumes vmm_bcast() is free to clobber at every call site
// of load_src_zp() -- confirm against callers.
if (!is_superset(brg.isa_impl, avx512_core) && is_src_zp_bcast_)
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
}

template <typename Wmm>
void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
int m_blocks, int n_blocks, bool has_n_tail) {
Expand All @@ -620,12 +632,10 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
lea(reg_s8s8_comp, ptr[reg_s8s8_comp + reg_aux_N * sizeof(int32_t)]);
}
if (compute_src_zp_) {
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
load_src_zp();
mov(reg_zp_compensation, ptr[rsp + zp_compensation_]);
lea(reg_zp_compensation,
ptr[reg_zp_compensation + reg_aux_N * sizeof(int32_t)]);
if (!is_superset(brg.isa_impl, avx512_core))
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
}

for_(int v_i = 0; v_i < v_substep; ++v_i)
Expand All @@ -640,16 +650,35 @@ void jit_brdgmm_kernel_base_t<Wmm>::compute_int8_compensation(
}
if (compute_src_zp_) {
// zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
const Vmm vmm_zp = vmm_zp_comp();
vmovups(vmm_zp,
maybe_EVEX_compress_addr(reg_zp_compensation, offset));
if (is_superset(brg.isa_impl, avx512_core)) {
const bool src_zp_is_common = true;
vpmulld(vmm_zp, vmm_zp,
maybe_EVEX_compress_addr(
reg_src_zero_point, 0, src_zp_is_common));
const bool is_tail
= n + 1 == n_blocks && has_n_tail && substep_simd < simd_w_;
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
? maybe_mask(vmm_zp_comp(), is_tail, false)
: vmm_zp_comp();
if (IMPLICATION(is_tail, isa_has_masks(brg.isa_impl))) {
vmovups(vmm_zp,
maybe_EVEX_compress_addr(reg_zp_compensation, offset));
if (is_src_zp_bcast_) {
if (is_superset(brg.isa_impl, avx512_core))
vpmulld(vmm_zp, vmm_zp,
maybe_EVEX_compress_addr(
reg_src_zero_point, 0, true));
else
vpmulld(vmm_zp, vmm_zp, vmm_bcast());
} else
vpmulld(vmm_zp, vmm_zp,
maybe_EVEX_compress_addr(
reg_src_zero_point, offset));
} else {
vpmulld(vmm_zp, vmm_zp, vmm_bcast());
const int tail_size = tail_length();
const Vmm ymm_tmp
= vmm_bcast(); // used for bcast or tail processing in avx2
load_data(data_type::s32, vmm_zp,
ptr[reg_zp_compensation + offset], tail_size);
if (!is_src_zp_bcast_)
load_data(data_type::s32, ymm_tmp,
ptr[reg_src_zero_point + offset], tail_size);
vpmulld(vmm_zp, vmm_zp, ymm_tmp);
}
}
for (int m = 0; m < m_blocks; m++) {
Expand Down Expand Up @@ -795,24 +824,48 @@ void jit_brdgmm_kernel_base_t<Wmm>::load_b(

template <typename Wmm>
void jit_brdgmm_kernel_base_t<Wmm>::comp_dot_product(
compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb) {
compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb, int n,
bool is_tail_block) {
switch (kernel_type) {
case compute_pad_kernel_t::s8s8_kernel:
vpdpbusd(vmm_acc, vmm_shift(), vmmb,
is_superset(brg.isa_impl, avx512_core)
? Xbyak::EvexEncoding
: Xbyak::VexEncoding);
break;
case compute_pad_kernel_t::zero_point_kernel:
if (is_superset(brg.isa_impl, avx512_core)) {
vpmulld(vmm_zp_comp(), vmmb,
maybe_EVEX_compress_addr(reg_src_zero_point, 0, true));
case compute_pad_kernel_t::zero_point_kernel: {
const Vmm vmm_zp = isa_has_masks(brg.isa_impl)
? maybe_mask(vmm_zp_comp(), is_tail_block, false)
: vmm_zp_comp();
const size_t offset = comp_offset(n);
if (IMPLICATION(is_tail_block, isa_has_masks(brg.isa_impl))) {
if (is_src_zp_bcast_) {
if (is_superset(brg.isa_impl, avx512_core))
vpmulld(vmm_zp, vmmb,
maybe_EVEX_compress_addr(
reg_src_zero_point, 0, true));
else
vpmulld(vmm_zp, vmmb, vmm_bcast());
} else {
const Xbyak::Address src_zp_addr = maybe_EVEX_compress_addr(
reg_src_zero_point, offset);
if (is_fast_vnni_int8()) {
vmovups(vmm_zp, src_zp_addr);
vpermd(vmm_zp, vmm_permute(), vmm_zp);
vpmulld(vmm_zp, vmmb, vmm_zp);
} else
vpmulld(vmm_zp, vmmb, src_zp_addr);
}
} else {
uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]);
vpmulld(vmm_zp_comp(), vmmb, vmm_bcast());
const Vmm ymm_tmp
= vmm_bcast(); // used for bcast or tail processing in avx2
if (!is_src_zp_bcast_)
load_data(data_type::s32, ymm_tmp,
ptr[reg_src_zero_point + offset], tail_length());
vpmulld(vmm_zp, vmmb, ymm_tmp);
}
vpaddd(vmm_acc, vmm_acc, vmm_zp_comp());
break;
} break;
default: assert(!"unsupported comp_kernel type");
}
}
Expand Down Expand Up @@ -853,21 +906,25 @@ void jit_brdgmm_kernel_base_t<Wmm>::pad_comp_kernel(

for (int pad_i = max_m_unroll; pad_i > 0; --pad_i) {
L(jmp_table_labels[pad_i]);
if (is_zero_point_kernel)
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
if (is_zero_point_kernel) load_src_zp();
if (pad_i > m_blocks) continue;
const int m_i = get_mi(pad_i);
int p_b_i = 0;
for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) {
if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
const int substep_simd = get_substep_simd(n_i, 0, has_tail);
if (substep_simd <= 0) continue;
const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
const bool is_tail_block
= n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_;
if (p_b_i < n_preload_b_vmms) {
comp_dot_product(kernel_type, vmm_acc, vmm_b(p_b_i));
comp_dot_product(
kernel_type, vmm_acc, vmm_b(p_b_i), n_i, is_tail_block);
} else {
// preloaded vmm_b not available
const Vmm vmm_wei = vmm_b(max_bvmms - 1);
load_b(vmm_wei, n_i, 0, has_tail, load_broadcast_wei);
comp_dot_product(kernel_type, vmm_acc, vmm_wei);
comp_dot_product(
kernel_type, vmm_acc, vmm_wei, n_i, is_tail_block);
}
}
}
Expand All @@ -885,8 +942,7 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
auto kernel_body = [&](compute_pad_kernel_t kernel_type) {
const bool is_zero_point_kernel
= kernel_type == compute_pad_kernel_t::zero_point_kernel;
if (is_zero_point_kernel)
lea(reg_src_zero_point, ptr[rsp + src_zp_value_]);
if (is_zero_point_kernel) load_src_zp();
for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) {
const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i;
for (int i = 0; i < n_e; ++i) {
Expand All @@ -898,9 +954,13 @@ void jit_brdgmm_kernel_base_t<Wmm>::batch_pad_kernel(
for_(int m_i = 0; m_i < m_blocks; ++m_i)
for (int i = 0; i < n_e; ++i) {
const int n_i = nb_i + i;
if (get_substep_simd(n_i, 0, has_tail) <= 0) continue;
const int substep_simd = get_substep_simd(n_i, 0, has_tail);
if (substep_simd <= 0) continue;
const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0);
comp_dot_product(kernel_type, vmm_acc, vmm_b(i));
const bool is_tail_block
= n_i + 1 == n_e && has_tail && substep_simd < simd_w_;
comp_dot_product(
kernel_type, vmm_acc, vmm_b(i), n_i, is_tail_block);
}
}
};
Expand Down
5 changes: 4 additions & 1 deletion src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
const int simd_w_;
const int max_vmms_;
const bool compute_dst_zp_, compute_src_zp_;
const bool is_src_zp_bcast_;
const bool compute_compensation_; // code-path for either s8s8 or src_zp
const bool has_vpad_; // vertical padding w.r.t. M dimension
const bool has_bpad_; // batch pad is computed for the overlap between the
Expand Down Expand Up @@ -341,7 +342,8 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
void load_b(
Vmm vmmb, int n_i, int v_i, bool has_n_tail, bool wei_zp = false);
void comp_dot_product(compute_pad_kernel_t kernel_type, Vmm vmm_acc,
Vmm vmmb); // int8 compensation dot_product (zp and s8s8)
Vmm vmmb, int n,
bool is_tail_block); // int8 compensation dot_product (zp and s8s8)
void pad_comp_kernel(compute_pad_kernel_t kernel_type, int m_blocks,
int n_blocks, int padding, const Xbyak::Reg64 reg_pad,
const std::function<int(int)> &get_mi, bool has_tail = false);
Expand All @@ -360,6 +362,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail);
void maybe_transpose_interleaved_vnni_to_plain(
int m_blocks, int n_blocks, bool has_n_tail);
void load_src_zp();
void compute_int8_compensation(int m_blocks, int n_blocks, bool has_n_tail);
void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail);
void store_accumulators_without_post_ops(
Expand Down
12 changes: 8 additions & 4 deletions src/cpu/x64/jit_brdgmm_dw_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
const bool params_ok
= IMPLICATION(has_zero_points, utils::one_of(jcp.src_dt, u8, s8))
&& IMPLICATION(jcp.src_zero_point,
attr()->zero_points_.common(DNNL_ARG_SRC))
attr()->zero_points_.common(DNNL_ARG_SRC)
|| attr()->zero_points_.per_dim_1(DNNL_ARG_SRC))
&& IMPLICATION(jcp.dst_zero_point,
attr()->zero_points_.common(DNNL_ARG_DST));
VDISPATCH_CONV(params_ok, VERBOSE_UNSUPPORTED_ZP_CFG);
Expand Down Expand Up @@ -583,7 +584,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

const int wei_scale_mask
Expand Down Expand Up @@ -753,8 +754,11 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
post_ops_data.scales = &oscales[jcp.is_oc_scale * ch];
post_ops_data.oc_logical_off = ch;
post_ops_data.dst_scales = dst_scales;
post_ops_data.zp_a_val
= jcp.src_zero_point ? src_zero_point : 1;
const bool is_bcast_zp
= pd()->attr()->zero_points_.common(DNNL_ARG_SRC);
post_ops_data.a_zp_values = jcp.src_zero_point
? src_zero_point + ch * !is_bcast_zp
: nullptr;
post_ops_data.c_zp_values
= jcp.dst_zero_point ? dst_zero_point : nullptr;
post_ops_data.a_zp_compensations
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/inputs/conv/test_conv_ci
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
--attr-zero-points=
--batch=shapes_basic
--attr-post-ops=
--attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1
--attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1,src:per_dim_1+dst:common:1
--batch=shapes_basic
### Signed input
--dt=s8:s8:s8
Expand All @@ -77,7 +77,7 @@
--attr-zero-points=
--batch=shapes_basic
--attr-post-ops=
--attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1
--attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1,src:per_dim_1+dst:common:1
--batch=shapes_basic
# BF32
--reset
Expand Down
Loading