Skip to content

Commit

Permalink
x64: brgemm: add support of scales for dst tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
akharito authored and tprimak committed Feb 23, 2023
1 parent 85171b0 commit b6170d1
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 20 deletions.
9 changes: 7 additions & 2 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
brgemm_p.c_zp_values = post_ops_data.c_zp_values;
brgemm_p.ptr_dst_scales = post_ops_data.dst_scales;
assert(brg_kernel);
(*brg_kernel)(&brgemm_p);
}
Expand Down Expand Up @@ -140,6 +141,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
brgemm_p.c_zp_values = post_ops_data.c_zp_values;
brgemm_p.ptr_dst_scales = post_ops_data.dst_scales;
assert(brg_kernel);
(*brg_kernel)(&brgemm_p);
}
Expand Down Expand Up @@ -330,9 +332,12 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
// that mask has correct value for this case
brg->is_oc_scale = wei_scales.mask_ != 0;
}
const bool scales_ok = src_scales.mask_ == 0

const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST);
brg->with_dst_scales = !dst_scales.has_default_values();
const bool scales_ok = src_scales.mask_ == 0 && dst_scales.mask_ == 0
&& attr->scales_.has_default_values(
{DNNL_ARG_SRC, DNNL_ARG_WEIGHTS});
{DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST});
if (!scales_ok) return status::unimplemented;

auto init_zp_type
Expand Down
24 changes: 20 additions & 4 deletions src/cpu/x64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -252,6 +252,7 @@ struct brgemm_t {
bool interleave_tilestores_ = false;

brgemm_prf_t prfA, prfB, prfC;
bool with_dst_scales = false;

bool is_row_major() const {
assert(layout != brgemm_layout_undef);
Expand Down Expand Up @@ -366,6 +367,7 @@ struct brgemm_kernel_params_t {
const void *c_zp_values = nullptr;
size_t skip_accm = 0;
int32_t zp_a_val = 1;
const void *ptr_dst_scales = nullptr;
};

template <cpu_isa_t isa, typename Vmm>
Expand Down Expand Up @@ -423,7 +425,10 @@ struct brdgmm_kernel_t : public brgemm_kernel_t {
};

/// @param bias Vector of bias (vector length is N)
/// @param scales Vector of scales (vector length is N)
/// @param scales - Vector of scale factor values which represents the combined
/// scale factors for matrices A and B. If brgemm_t::is_oc_scale = true the
/// vector length is N, otherwise it must be broadcast to a vector of simd
/// width length.
/// @param binary_post_ops_rhs - Ptr to table of pointers to tensors used as rhs
/// in binary post-operation { void* binary_op_tensor1, ...,
/// void* binary_op_tensor_n}
Expand All @@ -439,6 +444,15 @@ struct brdgmm_kernel_t : public brgemm_kernel_t {
/// @param skip_accumulation - specifies whether to skip accumulation when
/// computing post-ops. `Beta` value from descriptor affects final
/// accumulator values taken.
/// @param do_only_comp - specifies whether to perform accumulation only and skip
/// post-ops.
/// @param do_only_zp_a_val - specifies whether to apply only the pre-calculated
/// compensation for the A zero point and skip the remaining post-ops.
/// @param zp_a_val - zero point value for A, required to adjust compensation
/// values if do_only_zp_a_val = true.
/// @param dst_scales - Vector of inverted scale factor values for matrix C.
/// Only the common scale vector type is supported; it must be broadcast to
/// a vector of simd width length.
///
struct brgemm_post_ops_data_t {
brgemm_post_ops_data_t() = default;
Expand All @@ -450,7 +464,7 @@ struct brgemm_post_ops_data_t {
const void *b_zp_compensations = nullptr,
const void *c_zp_values = nullptr, bool skip_accumulation = false,
int32_t zp_a_val = 1, bool do_only_comp = false,
bool do_only_zp_a_val = false)
bool do_only_zp_a_val = false, const float *dst_scales = nullptr)
: bias(bias)
, scales(scales)
, binary_post_ops_rhs(binary_post_ops_rhs)
Expand All @@ -464,7 +478,8 @@ struct brgemm_post_ops_data_t {
, skip_accumulation(skip_accumulation)
, zp_a_val {zp_a_val}
, do_only_comp {do_only_comp}
, do_only_zp_a_val {do_only_zp_a_val} {}
, do_only_zp_a_val {do_only_zp_a_val}
, dst_scales(dst_scales) {}

const void *bias = nullptr;
const float *scales = nullptr;
Expand All @@ -480,6 +495,7 @@ struct brgemm_post_ops_data_t {
int32_t zp_a_val = 1;
const bool do_only_comp = false;
const bool do_only_zp_a_val = false;
const float *dst_scales = nullptr;
};

} // namespace x64
Expand Down
30 changes: 26 additions & 4 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -121,6 +121,11 @@ void jit_brdgmm_kernel_base_t<isa, Wmm>::read_params() {
mov(ptr[rsp + reg_scales_offs_], reg_tmp);
}

if (brg.with_dst_scales) {
mov(reg_tmp, ptr[param1 + GET_OFF(ptr_dst_scales)]);
mov(ptr[rsp + reg_dst_scales_offs_], reg_tmp);
}

if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1);
}

Expand Down Expand Up @@ -373,6 +378,23 @@ void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators_apply_post_ops(

if (postops_injector_) apply_post_ops(m_blocks, n_blocks, has_n_tail);

if (brg.with_dst_scales) {
mov(reg_aux_dst_scales, ptr[rsp + reg_dst_scales_offs_]);
auto vmm_dst_scales = vmm_tmp(0);
vbroadcastss(vmm_dst_scales, ptr[reg_aux_dst_scales]);

for_(int m = 0; m < m_blocks; m++)
for_(int n = 0; n < n_blocks; n++)
for (int v_i = 0; v_i < v_substep; ++v_i) {
const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
if (substep_simd <= 0) continue;
const bool mask_flag = substep_simd < simd_w_;
const Vmm vmm = maybe_mask(
accm(m_blocks, n_blocks, m, n, v_i), mask_flag, false);
vmulps(vmm, vmm, ptr_b[reg_aux_dst_scales]);
}
}

const bool dt_requires_saturation
= one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
auto vmm_lbound = vmm_tmp(0);
Expand Down Expand Up @@ -517,9 +539,9 @@ void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators(
}
}

const bool are_post_ops_applicable
= one_of(true, brg.with_eltwise, brg.with_binary, brg.with_scales,
brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c);
const bool are_post_ops_applicable = one_of(true, brg.with_eltwise,
brg.with_binary, brg.with_scales, brg.with_bias, brg.with_sum,
brg.dt_d != brg.dt_c, brg.with_dst_scales);

Label label_done;
if (are_post_ops_applicable) {
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,6 +89,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
const reg64_t reg_total_padding = reg_table_base;
const reg64_t reg_aux_bias = reg_table_base;
const reg64_t reg_aux_scales = reg_table_base;
const reg64_t reg_aux_dst_scales = reg_table_base;
const reg64_t reg_binary_params = abi_param1; // default for binary ops
const reg64_t reg_ptr_sum_scale = reg_aux_A_vpad_top;
const reg64_t reg_ptr_sum_zp = reg_aux_A_vpad_bottom;
Expand Down Expand Up @@ -116,7 +117,8 @@ struct jit_brdgmm_kernel_base_t : public jit_generator {
constexpr static int reg_A_offs_ = 24; // brgemm_strd
constexpr static int reg_B_offs_ = 32; // brgemm_strd
constexpr static int abi_param1_offs_ = 40;
constexpr static int stack_space_needed_ = 48;
constexpr static int reg_dst_scales_offs_ = 48;
constexpr static int stack_space_needed_ = 56;

bool with_binary_non_scalar_bcast_ = false;

Expand Down
32 changes: 27 additions & 5 deletions src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -128,6 +128,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator {
const reg64_t reg_BS_loop = r9;
const reg64_t reg_bias = rbx;
const reg64_t reg_scales = rbx;
const reg64_t reg_dst_scales = rbx;

const reg64_t reg_stride_ld_block = rdx;
const reg64_t reg_do_post_ops = rbx;
Expand Down Expand Up @@ -1146,18 +1147,39 @@ void jit_brgemm_amx_uker_base_t::process_output_range(brgemm_iteration_t &bi,
if (postops_injector_) {
apply_post_ops_to_range(bi, bd_start, bd_finish, bd_inp_bdb, ldb);
}

if (brg.with_dst_scales) {
mov(reg_dst_scales, ptr[param1 + GET_OFF(ptr_dst_scales)]);
auto zmm_dst_scales = zmm_tmp_1();
vbroadcastss(zmm_dst_scales, ptr[reg_dst_scales]);
for (int bd = bd_start; bd < bd_finish; bd++) {
const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
if (bd_out_bd == -1) continue;

auto zmm = accm(bd);
vmulps(zmm, zmm, zmm_dst_scales);
}
}

if (brg.zp_type_c != brgemm_broadcast_t::none) {
for (int bd = bd_start; bd < bd_finish; bd++) {
const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
if (bd_out_bd == -1) continue;

auto zmm = accm(bd);
vaddps(zmm, zmm, zmm_zp_c);
}
}
}

void jit_brgemm_amx_uker_base_t::store_vector_with_post_ops(const int idx,
const Address &addr, const int bd, const int ldb, bool is_ld_tail) {
auto zmm = Zmm(idx);
auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;

if (brg.zp_type_c != brgemm_broadcast_t::none) vaddps(zmm, zmm, zmm_zp_c);

maybe_saturation(zmm);

auto ymm = Xbyak::Ymm(idx);
auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
const Xbyak::Zmm r_zmm = zmm_mask(zmm, true, true, k_mask);
const Xbyak::Ymm r_ymm = ymm_mask(ymm, true, true, k_mask);

Expand Down Expand Up @@ -2037,7 +2059,7 @@ void jit_brgemm_amx_uker_base_t::generate() {
brg.zp_type_a, brg.zp_type_b, brg.zp_type_c);
are_post_ops_applicable_ = one_of(true, brg.with_eltwise, brg.with_binary,
brg.with_scales, brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c,
has_zero_points);
has_zero_points, brg.with_dst_scales);

// second level blocking eligible only if we don't use store by vectors for now
assert(IMPLICATION(are_post_ops_applicable_ || need_to_apply_alpha_beta_
Expand Down
29 changes: 26 additions & 3 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
const reg64_t reg_bias = reg_rdb_loop;
const reg64_t reg_scales = reg_rdb_loop;
const reg64_t reg_aux_bias = reg_rdb_loop;
const reg64_t reg_dst_scales = reg_rdb_loop;
const reg64_t reg_binary_postops_oc_l = reg_rdb_loop;
const reg64_t reg_aux_binary_postops_oc_l = reg_rdb_loop;
const reg64_t reg_aux_binary_postops_sp = reg_rdb_loop;
Expand All @@ -173,6 +174,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
const reg64_t reg_aux_zp_c_values = reg_rdb_loop;

const reg64_t reg_aux_scales = reg_aux_B;
const reg64_t reg_aux_dst_scales = reg_aux_B;
const reg64_t reg_do_post_ops = reg_rdb_loop;
const reg64_t reg_do_comp = reg_rdb_loop;
const reg64_t reg_skip_accm = reg_rdb_loop;
Expand Down Expand Up @@ -220,7 +222,8 @@ struct jit_brgemm_kernel_t : public jit_generator {
constexpr static int reg_skip_accm_offs_ = 192;
constexpr static int reg_zp_a_val_offs_ = 200;
constexpr static int reg_do_comp_offs_ = 208;
constexpr static int stack_space_needed_ = 216;
constexpr static int reg_dst_scales_offs_ = 216;
constexpr static int stack_space_needed_ = 224;

bool is_ldb_loop_ = false;
bool handle_binary_po_offset_ = false;
Expand Down Expand Up @@ -907,6 +910,11 @@ void jit_brgemm_kernel_t<isa, Wmm>::read_params() {
mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values);
}

if (brg.with_dst_scales) {
mov(reg_dst_scales, ptr[param1 + GET_OFF(ptr_dst_scales)]);
mov(ptr[rsp + reg_dst_scales_offs_], reg_dst_scales);
}

mov(reg_do_post_ops, ptr[param1 + GET_OFF(do_post_ops)]);
mov(ptr[rsp + reg_do_post_ops_offs_], reg_do_post_ops);

Expand Down Expand Up @@ -1142,6 +1150,19 @@ void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators_apply_post_ops(
if (postops_injector_)
apply_post_ops(bd_block, ld_block2, ldb_and_bdb_offset, is_ld_tail);

if (brg.with_dst_scales) {
mov(reg_aux_dst_scales, ptr[rsp + reg_dst_scales_offs_]);
auto vmm_dst_scales = vmm_tmp_1();
vbroadcastss(vmm_dst_scales, ptr[reg_aux_dst_scales]);

for (int ld = 0; ld < ld_block2; ld++) {
for (int bd = 0; bd < bd_block; bd++) {
auto vmm = accm(ld_block2, bd, ld);
vmulps(vmm, vmm, vmm_dst_scales);
}
}
}

if (brg.zp_type_c != brgemm_broadcast_t::none) {
mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
auto vmm_zp_c = vmm_tmp_1();
Expand Down Expand Up @@ -1348,7 +1369,8 @@ void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators(int bd_block2,
brg.zp_type_a, brg.zp_type_b, brg.zp_type_c);
const bool are_post_ops_applicable = one_of(true, brg.with_eltwise,
brg.with_binary, brg.with_scales, brg.with_bias, brg.with_sum,
brg.dt_d != brg.dt_c, brg.req_s8s8_compensation, has_zero_points);
brg.dt_d != brg.dt_c, brg.req_s8s8_compensation, has_zero_points,
brg.with_dst_scales);
const bool need_to_apply_alpha_beta = brg.beta != 0.f || brg.alpha != 1.f;

if (brg.is_tmm) {
Expand Down Expand Up @@ -1420,7 +1442,8 @@ void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators(int bd_block2,
post_processed |= utils::one_of(true, brg.with_bias,
brg.with_scales, with_binary_per_oc_bcast_,
brg.zp_type_a != brgemm_broadcast_t::none,
brg.zp_type_c == brgemm_broadcast_t::per_n);
brg.zp_type_c == brgemm_broadcast_t::per_n,
brg.with_dst_scales);
}
if (bdb < bd_block2 - 1) {
advance_bdb_post_op_regs(adj_bd_block);
Expand Down

0 comments on commit b6170d1

Please sign in to comment.