Skip to content

Commit

Permalink
x64: support dst scales in brgemm-based implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
akharito authored and tprimak committed Feb 23, 2023
1 parent 18de927 commit 38319f1
Show file tree
Hide file tree
Showing 17 changed files with 119 additions and 47 deletions.
7 changes: 5 additions & 2 deletions src/cpu/x64/jit_brdgmm_dw_conv.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -195,7 +195,8 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
|| !wei_scales.has_default_values();
jcp.is_oc_scale = wei_scales.mask_ != 0;

const bool scales_ok = attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS});
const bool scales_ok
= attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST});
if (!scales_ok) return status::unimplemented;

// strd is only feasible for 1D (i.e., height dim is one)
Expand Down Expand Up @@ -388,6 +389,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {

DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
src_scales, wei_scales, pd()->OC(), pd()->attr());
Expand Down Expand Up @@ -516,6 +518,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
post_ops_data.bias = bias + ch * jcp.bia_dsz;
post_ops_data.scales = &oscales[jcp.is_oc_scale * ch];
post_ops_data.oc_logical_off = ch;
post_ops_data.dst_scales = dst_scales;
brgemm_kernel_execute_postops(kernel, bs, ptr_A, ptr_B,
brg_batch, ptr_C, ptr_C, post_ops_data,
nullptr /*scratch*/);
Expand Down
12 changes: 8 additions & 4 deletions src/cpu/x64/jit_brgemm_1x1_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
int od, int oh, int ow, int icc, int *last_palette_idx,
const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp,
int32_t *dst_zp_vals, int32_t *s8s8_compensation) const {
int32_t *dst_zp_vals, int32_t *s8s8_compensation,
const float *dst_scales) const {

const memory_desc_wrapper src_d(pd()->src_md());
const memory_desc_wrapper weights_d(pd()->weights_md());
Expand Down Expand Up @@ -428,7 +429,8 @@ void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
post_ops_binary_rhs_arg_vec.data(),
static_cast<size_t>(g_oc), 0, dst, 0,
static_cast<void *>(src_zp_comp_ptr), nullptr,
static_cast<void *>(dst_zp_vals), false, src_zp_vals};
static_cast<void *>(dst_zp_vals), false, src_zp_vals, false,
false, dst_scales};

void *scratch = is_amx ? static_cast<void *>(wsp_tile)
: static_cast<void *>(s8s8_comp_ptr);
Expand Down Expand Up @@ -473,6 +475,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(

DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
src_scales, wei_scales, pd()->OC(), pd()->attr());
Expand Down Expand Up @@ -553,7 +556,8 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, \
inp_buffer_sp, g, n, ocb, od, oh, ow, icc, \
&last_palette_idx, oscales, src_zero_point, \
zp_compensation, dst_zp_vals, s8s8_compensation); \
zp_compensation, dst_zp_vals, s8s8_compensation, \
dst_scales); \
} \
} \
last_n = n; \
Expand Down Expand Up @@ -595,7 +599,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, nullptr, g, n, \
ocb, od, oh, ow, icc, &last_palette_idx, oscales, \
src_zero_point, zp_compensation, dst_zp_vals, \
s8s8_compensation); \
s8s8_compensation, dst_scales); \
} \
nd_iterator_step(__VA_ARGS__); \
} \
Expand Down
8 changes: 5 additions & 3 deletions src/cpu/x64/jit_brgemm_1x1_conv.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,7 +64,8 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t {

protected:
bool arg_scales_ok() const {
std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
std::vector<int> supported_args
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
return attr_scales_ok(supported_args);
}
bool zero_points_ok() const {
Expand Down Expand Up @@ -121,7 +122,8 @@ struct brgemm_1x1_convolution_fwd_t : public primitive_t {
char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
int od, int oh, int ow, int icc, int *last_brg_idx,
const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp,
int32_t *dst_zp_vals, int32_t *s8s8_compensation) const;
int32_t *dst_zp_vals, int32_t *s8s8_compensation,
const float *dst_scales) const;
status_t execute_forward_all(const exec_ctx_t &ctx) const;
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

Expand Down
16 changes: 10 additions & 6 deletions src/cpu/x64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,7 @@ struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
int32_t *src_zp_comp_ptr;
int32_t *dst_zp_vals;
int32_t *s8s8_comp_ptr;
const float *dst_scales {nullptr};
};

template <cpu_isa_t isa, bool use_inversion>
Expand All @@ -884,6 +885,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(

DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
src_scales, wei_scales, _pd->OC(), _pd->attr());
Expand Down Expand Up @@ -1012,6 +1014,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
btc.src_zp_comp_ptr
= jcp.src_zero_point ? src_zp_comp_base : nullptr;
btc.s8s8_comp_ptr = jcp.s8s8_avx512 ? s8s8_comp_base : nullptr;
btc.dst_scales = dst_scales;

if (jcp.exec_type == exec_trans && (last_n != n || last_g != g)) {
if (!jcp.copy_block_only)
Expand Down Expand Up @@ -1133,7 +1136,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
int kd_l, int kh_l, const void *post_ops_binary_rhs_arg_vec,
const float *oscales, int32_t src_zp_vals, int32_t *src_zp_ptr,
int32_t *dst_zp_ptr, int32_t *s8s8_compensation, bool maybe_do_init,
bool do_postwork, bool do_post_comp) const {
bool do_postwork, bool do_post_comp, const float *dst_scales) const {

const auto _pd = pd();
const auto &jcp = _pd->jcp_;
Expand All @@ -1160,6 +1163,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
p.dst_orig = dst;
p.c_zp_values = dst_zp_ptr;
p.a_comp_val = src_zp_vals;
p.ptr_dst_scales = (void *)dst_scales;
}

auto call_outwork_ker = [&](bool is_postwork, bool has_postcomp,
Expand Down Expand Up @@ -1246,7 +1250,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
static_cast<size_t>(g_oc), 0, btc.brgemm_ctx.dst, 0,
static_cast<void *>(src_zp_ptr), nullptr,
static_cast<void *>(dst_zp_ptr), false, src_zp_vals,
do_only_comp, do_only_pass_comp};
do_only_comp, do_only_pass_comp, btc.dst_scales};

void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile)
: static_cast<void *>(s8s8_comp);
Expand Down Expand Up @@ -1581,7 +1585,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
g_oc, is_oc_tail, ow_b, ow_e, kd_l, kh_l,
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
btc.s8s8_comp_ptr, do_init, do_postwork, false);
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
};

if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
Expand Down Expand Up @@ -1636,7 +1640,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
btc.s8s8_comp_ptr, do_init, do_postwork, false);
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
}
}

Expand Down Expand Up @@ -1792,7 +1796,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
btc.s8s8_comp_ptr, do_init, do_postwork, false);
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
}
}

Expand Down Expand Up @@ -1944,7 +1948,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
post_ops_binary_rhs_arg_vec.data(), btc.oscales,
btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
btc.s8s8_comp_ptr, do_init, do_postwork, false);
btc.s8s8_comp_ptr, do_init, do_postwork, false, btc.dst_scales);
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/cpu/x64/jit_brgemm_conv.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -105,7 +105,8 @@ struct brgemm_convolution_fwd_t : public primitive_t {

protected:
bool arg_scales_ok() const {
std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
std::vector<int> supported_args
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
return attr_scales_ok(supported_args);
}

Expand Down Expand Up @@ -181,7 +182,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {
const void *post_ops_binary_rhs_arg_vec, const float *oscales,
int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
int32_t *s8s8_compensation, bool maybe_do_init, bool do_postwork,
bool do_post_comp) const;
bool do_post_comp, const float *dst_scales) const;

void call_brgemm_kernel(brgemm_thread_ctx_t &btc, int brg_idx,
int batch_size, char *ptr_C, char *ptr_D, const char *bias_w,
Expand Down
14 changes: 10 additions & 4 deletions src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/jit_brgemm_conv_bwd_strided.hpp"
#include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp"
Expand Down Expand Up @@ -369,8 +370,12 @@ status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::execute(
const auto _pd = pd();
const auto &jcp = _pd->jcp_;

// XXX: brgemm requires scales to be passed, so passing default wei scales
DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
src_scales, wei_scales, _pd->IC(), _pd->attr());

const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
brgemm_batch_element_t *const __restrict brg_batch_global
Expand Down Expand Up @@ -451,6 +456,7 @@ status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::execute(
btc.ihb = ihb;
btc.iwb = iwb;
btc.oscales = oscales;
btc.dst_scales = dst_scales;

auto id_begin = idb * jcp.id_block;
auto id_end = nstl::min(ID, id_begin + jcp.id_block);
Expand Down Expand Up @@ -526,7 +532,7 @@ void brgemm_convolution_bwd_strided_t<isa, enable_postops>::call_brgemm_kernel(
static_cast<size_t>(g_ic), 0, btc.brgemm_ctx.dst, 0,
static_cast<void *>(src_zp_ptr), nullptr,
static_cast<void *>(dst_zp_ptr), do_skip_accm, src_zp_vals,
do_only_comp, do_only_pass_comp};
do_only_comp, do_only_pass_comp, btc.dst_scales};

void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile)
: static_cast<void *>(s8s8_comp);
Expand Down
3 changes: 2 additions & 1 deletion src/cpu/x64/jit_brgemm_conv_bwd_strided.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -145,6 +145,7 @@ struct brgemm_convolution_bwd_strided_t : public primitive_t {
int occ;
int sw;
const float *oscales {nullptr};
const float *dst_scales {nullptr};
};

void ker_trans(brgemm_bwd_thread_ctx_t &btc, char *inp_buffer) const;
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2021-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
16 changes: 11 additions & 5 deletions src/cpu/x64/jit_brgemm_inner_product.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -78,6 +78,7 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(

DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
src_scales, wei_scales, pd()->OC(), pd()->attr());
Expand Down Expand Up @@ -108,7 +109,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(

const bool are_post_ops_applicable = one_of(true, jbgp.with_sum,
jbgp.with_bias, jbgp.with_scales, jbgp.with_eltwise,
jbgp.with_binary, jbgp.acc_dt != jbgp.dst_dt, jbgp.signed_input);
jbgp.with_binary, jbgp.acc_dt != jbgp.dst_dt, jbgp.signed_input,
jbgp.with_dst_scales);

size_t offset = types::data_type_size(jbgp.wei_dt)
* (weights_d.size() - weights_d.additional_buffer_size());
Expand Down Expand Up @@ -221,7 +223,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
static_cast<const void *>(ptr_bias),
&oscales[jbgp.is_oc_scale * oc],
post_ops_binary_rhs_arg_vec.data(),
static_cast<size_t>(oc), 0, dst};
static_cast<size_t>(oc), 0, dst, 0, nullptr, nullptr,
nullptr, false, 1, false, false, dst_scales};

brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
Expand Down Expand Up @@ -264,7 +267,8 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
static_cast<const void *>(ptr_bias),
&oscales[jbgp.is_oc_scale * oc],
post_ops_binary_rhs_arg_vec.data(),
static_cast<size_t>(oc), 0, dst};
static_cast<size_t>(oc), 0, dst, 0, nullptr, nullptr,
nullptr, false, 1, false, false, dst_scales};

brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
(void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
Expand Down Expand Up @@ -457,7 +461,9 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
&oscales[jbgp.is_oc_scale * oc],
post_ops_binary_rhs_arg_vec.data(),
static_cast<size_t>(oc), 0, dst, 0, nullptr,
nullptr, nullptr, true /* skip_accm */};
nullptr, nullptr, true /* skip_accm */, 1,
false, false, dst_scales};

brgemm_kernel_execute_postops(brg_kernel, 0,
nullptr, (void *)ptr_C, (void *)ptr_D,
post_ops_data, scratch);
Expand Down
7 changes: 4 additions & 3 deletions src/cpu/x64/jit_brgemm_inner_product.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,7 +81,7 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
bool are_post_ops_applicable = one_of(true, jbgp_.with_sum,
jbgp_.with_bias, jbgp_.with_scales, jbgp_.with_eltwise,
jbgp_.with_binary, jbgp_.acc_dt != jbgp_.dst_dt,
jbgp_.signed_input);
jbgp_.signed_input, jbgp_.with_dst_scales);

const float alpha = 1.0;
const float beta = 1.0;
Expand Down Expand Up @@ -142,7 +142,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t {
}

bool arg_scales_ok() const {
std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
std::vector<int> supported_args
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
return attr_scales_ok(supported_args);
}

Expand Down
4 changes: 3 additions & 1 deletion src/cpu/x64/jit_brgemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -299,6 +299,7 @@ status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
jbgp.is_oc_scale = wei_scales.mask_ != 0;
}

const int min_ic_divisor = is_amx_int8 ? 4 : is_amx_xf16 ? 2 : 1;

jbgp.use_buffer_a = jbgp.ic % min_ic_divisor != 0;
Expand Down Expand Up @@ -995,6 +996,7 @@ status_t init_ip_conf(cpu_isa_t isa, jit_brgemm_primitive_conf_t &jbgp,
if (is_int8) {
jbgp.acc_dt = s32;
jbgp.with_scales = true;
jbgp.with_dst_scales = true;
} else
jbgp.acc_dt = f32;

Expand Down
Loading

0 comments on commit 38319f1

Please sign in to comment.