From bf3eabdde260c6075f8fbc25221729cf524cff9c Mon Sep 17 00:00:00 2001 From: dmitrygo Date: Thu, 24 Sep 2020 23:56:52 +0300 Subject: [PATCH] [FORK][FEATURE] Introduced Depthwise and Quantization post ops Primitives that supports new post ops: - Jit Convolutions (FP32,BF16,INT8) - Jit Deconvolution (INT8) - Jit,Ref Pooling (INT8) AMX primitives: explicilty pass dst_type into has_default_values checks Extended int8 AMX convolutions to support depthwise/quantization post ops ONEDNN 3.2 migration squashed commits: - Correct assert for reg id in dw conv kernel ONEDNN 3.5 squash list: [FIX] fix avx2 int8 binary postops reg conflict --- include/oneapi/dnnl/dnnl.h | 10 + include/oneapi/dnnl/dnnl.hpp | 22 ++ include/oneapi/dnnl/dnnl_types.h | 10 + src/common/c_types_map.hpp | 6 + src/common/dnnl_debug_autogenerated.cpp | 6 + src/common/ittnotify.cpp | 2 + src/common/math_utils.hpp | 37 +++ src/common/nstl.hpp | 4 + src/common/primitive_attr.cpp | 97 +++++++ src/common/primitive_attr.hpp | 109 +++++++ src/common/primitive_hashing.cpp | 14 + src/common/verbose.cpp | 8 + src/cpu/cpu_pooling_list.cpp | 16 +- src/cpu/gemm_convolution_utils.hpp | 21 ++ src/cpu/ref_pooling.cpp | 26 +- src/cpu/ref_pooling.hpp | 39 ++- src/cpu/x64/gemm_bf16_convolution.hpp | 2 +- .../injectors/jit_uni_depthwise_injector.cpp | 265 ++++++++++++++++++ .../injectors/jit_uni_depthwise_injector.hpp | 135 +++++++++ .../injectors/jit_uni_postops_injector.cpp | 107 ++++++- .../injectors/jit_uni_postops_injector.hpp | 32 ++- .../jit_uni_quantization_injector.cpp | 248 ++++++++++++++++ .../jit_uni_quantization_injector.hpp | 111 ++++++++ src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp | 30 +- src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp | 6 + src/cpu/x64/jit_avx2_1x1_convolution.cpp | 4 + src/cpu/x64/jit_avx2_1x1_convolution.hpp | 7 +- src/cpu/x64/jit_avx2_conv_kernel_f32.cpp | 26 +- src/cpu/x64/jit_avx2_conv_kernel_f32.hpp | 6 + src/cpu/x64/jit_avx2_convolution.cpp | 3 +- src/cpu/x64/jit_avx2_convolution.hpp | 3 + .../x64/jit_avx512_common_1x1_conv_kernel.cpp | 28 +- .../x64/jit_avx512_common_1x1_conv_kernel.hpp | 9 +- .../x64/jit_avx512_common_1x1_convolution.cpp | 3 + .../x64/jit_avx512_common_1x1_convolution.hpp | 5 +- src/cpu/x64/jit_avx512_common_conv_kernel.cpp | 25 +- src/cpu/x64/jit_avx512_common_conv_kernel.hpp | 6 + src/cpu/x64/jit_avx512_common_convolution.cpp | 36 ++- src/cpu/x64/jit_avx512_common_convolution.hpp | 2 + .../jit_avx512_core_amx_1x1_conv_kernel.cpp | 37 ++- .../jit_avx512_core_amx_1x1_conv_kernel.hpp | 10 +- .../jit_avx512_core_amx_1x1_convolution.cpp | 2 +- .../jit_avx512_core_amx_1x1_convolution.hpp | 2 +- .../x64/jit_avx512_core_amx_conv_kernel.cpp | 41 ++- .../x64/jit_avx512_core_amx_conv_kernel.hpp | 8 +- .../x64/jit_avx512_core_amx_convolution.cpp | 2 + .../x64/jit_avx512_core_amx_convolution.hpp | 2 +- .../jit_avx512_core_bf16_1x1_conv_kernel.cpp | 35 ++- .../jit_avx512_core_bf16_1x1_conv_kernel.hpp | 6 + .../jit_avx512_core_bf16_1x1_convolution.cpp | 3 + .../jit_avx512_core_bf16_1x1_convolution.hpp | 2 +- .../x64/jit_avx512_core_bf16_conv_kernel.cpp | 27 +- .../x64/jit_avx512_core_bf16_conv_kernel.hpp | 8 +- .../x64/jit_avx512_core_bf16_convolution.cpp | 6 + .../jit_avx512_core_bf16_dw_conv_kernel.cpp | 28 +- .../jit_avx512_core_bf16_dw_conv_kernel.hpp | 10 +- ...t_avx512_core_x8s8s32x_1x1_conv_kernel.cpp | 35 ++- ...t_avx512_core_x8s8s32x_1x1_conv_kernel.hpp | 8 + ...t_avx512_core_x8s8s32x_1x1_convolution.cpp | 2 + ...t_avx512_core_x8s8s32x_1x1_convolution.hpp | 2 + .../jit_avx512_core_x8s8s32x_conv_kernel.cpp | 46 ++- .../jit_avx512_core_x8s8s32x_conv_kernel.hpp | 8 +- .../jit_avx512_core_x8s8s32x_convolution.cpp | 7 + .../jit_avx512_core_x8s8s32x_convolution.hpp | 2 + ...jit_avx512_core_x8s8s32x_deconvolution.cpp | 48 +++- ...jit_avx512_core_x8s8s32x_deconvolution.hpp | 10 +- src/cpu/x64/jit_gemm_inner_product_utils.cpp | 2 +- src/cpu/x64/jit_primitive_conf.hpp | 12 + src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp | 30 +- src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp | 9 +- src/cpu/x64/jit_sse41_1x1_convolution.cpp | 3 + src/cpu/x64/jit_sse41_1x1_convolution.hpp | 5 +- src/cpu/x64/jit_sse41_conv_kernel_f32.cpp | 29 +- src/cpu/x64/jit_sse41_conv_kernel_f32.hpp | 7 + src/cpu/x64/jit_sse41_convolution.cpp | 3 +- src/cpu/x64/jit_sse41_convolution.hpp | 3 + src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp | 31 +- src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp | 9 +- src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp | 4 +- src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp | 4 +- src/cpu/x64/jit_uni_dw_convolution.cpp | 2 + src/cpu/x64/jit_uni_dw_convolution.hpp | 5 +- src/cpu/x64/jit_uni_i8i8_pooling.cpp | 91 ++++-- src/cpu/x64/jit_uni_i8i8_pooling.hpp | 4 + .../x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp | 36 ++- .../x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp | 10 +- .../x64/jit_uni_x8s8s32x_1x1_convolution.cpp | 2 + .../x64/jit_uni_x8s8s32x_1x1_convolution.hpp | 3 + src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp | 33 ++- src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp | 6 + src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp | 4 + src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp | 3 + .../x64/jit_uni_x8s8s32x_deconvolution.cpp | 35 ++- .../x64/jit_uni_x8s8s32x_deconvolution.hpp | 7 + 94 files changed, 2078 insertions(+), 237 deletions(-) create mode 100644 src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp create mode 100644 src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp create mode 100644 src/cpu/x64/injectors/jit_uni_quantization_injector.cpp create mode 100644 src/cpu/x64/injectors/jit_uni_quantization_injector.hpp diff --git a/include/oneapi/dnnl/dnnl.h b/include/oneapi/dnnl/dnnl.h index 2c9d2a4cb9f..8f9cef177e0 100644 --- a/include/oneapi/dnnl/dnnl.h +++ b/include/oneapi/dnnl/dnnl.h @@ -795,6 +795,16 @@ dnnl_status_t DNNL_API dnnl_post_ops_append_prelu( dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu( const_dnnl_post_ops_t post_ops, int index, int *mask); +dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise( + dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, + const float* weights_data, const float* biases_data); + +dnnl_status_t DNNL_API dnnl_post_ops_append_quantization( + dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, + const void* crop_low, const void* crop_high, + const void* input_scale, const void* input_shift, + const void* output_scale, const void* output_shift); + /// @} dnnl_api_attributes /// @} dnnl_api_primitives diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp index 9772f0b888b..25dfebd3397 100644 --- a/include/oneapi/dnnl/dnnl.hpp +++ b/include/oneapi/dnnl/dnnl.hpp @@ -496,6 +496,12 @@ enum class algorithm { softmax_accurate = dnnl_softmax_accurate, /// LogSoftmax, numerically stable softmax_log = dnnl_softmax_log, + + depthwise_scale_shift = dnnl_depthwise_scale_shift, + depthwise_prelu = dnnl_depthwise_prelu, + + quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize, + quantization_quantize = dnnl_quantization_quantize, }; /// Converts algorithm kind enum value from C++ API to C API type. @@ -3924,6 +3930,22 @@ struct post_ops : public handle { error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask), "could not get parameters of a binary post-op"); } + + void append_depthwise(algorithm alg, const float* weights_data, + const float* biases_data) { + error::wrap_c_api(dnnl_post_ops_append_depthwise(get(), + convert_to_c(alg), weights_data, biases_data), + "could not append depthwise"); + } + + void append_quantization(algorithm alg, + const void* crop_low, const void* crop_high, + const void* input_scale, const void* input_shift, + const void* output_scale, const void* output_shift) { + error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), crop_low, crop_high, + input_scale, input_shift, output_scale, output_shift), + "could not append quantization"); + } }; /// @cond DO_NOT_DOCUMENT_THIS diff --git a/include/oneapi/dnnl/dnnl_types.h b/include/oneapi/dnnl/dnnl_types.h index a56c1ac02b8..40bd71b3c64 100644 --- a/include/oneapi/dnnl/dnnl_types.h +++ b/include/oneapi/dnnl/dnnl_types.h @@ -1992,6 +1992,10 @@ typedef enum { dnnl_deconvolution, /// An element-wise primitive. dnnl_eltwise, + /// An depthwise-wise primitive. + dnnl_depthwise, + /// A quantization primitive. + dnnl_quantization, /// An LRN primitive. dnnl_lrn, /// A batch normalization primitive. @@ -2176,6 +2180,12 @@ typedef enum { dnnl_softmax_accurate = 0x30000, /// Logsoftmax dnnl_softmax_log, + + dnnl_depthwise_scale_shift = 0x3fff0, + dnnl_depthwise_prelu = 0x3fff1, + + dnnl_quantization_quantize_dequantize = 0x4fff0, + dnnl_quantization_quantize = 0x4fff1, } dnnl_alg_kind_t; /// Flags for normalization primitives. diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index 6fe6ba8deca..eb83d4285ca 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -141,6 +141,10 @@ const alg_kind_t reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum; const alg_kind_t softmax_accurate = dnnl_softmax_accurate; const alg_kind_t softmax_log = dnnl_softmax_log; +const alg_kind_t depthwise_scale_shift = dnnl_depthwise_scale_shift; +const alg_kind_t depthwise_prelu = dnnl_depthwise_prelu; +const alg_kind_t quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize; +const alg_kind_t quantization_quantize = dnnl_quantization_quantize; } // namespace alg_kind using data_type_t = dnnl_data_type_t; @@ -1949,6 +1953,8 @@ const primitive_kind_t reduction = dnnl_reduction; const primitive_kind_t softmax = dnnl_softmax; const primitive_kind_t layer_normalization = dnnl_layer_normalization; const primitive_kind_t group_normalization = dnnl_group_normalization; +const primitive_kind_t depthwise = dnnl_depthwise; +const primitive_kind_t quantization = dnnl_quantization; // Internal only primitive kinds. const primitive_kind_t internal_only_start = (primitive_kind_t)(1 << 12); diff --git a/src/common/dnnl_debug_autogenerated.cpp b/src/common/dnnl_debug_autogenerated.cpp index 38e197c30b8..76a7880f585 100644 --- a/src/common/dnnl_debug_autogenerated.cpp +++ b/src/common/dnnl_debug_autogenerated.cpp @@ -1753,6 +1753,8 @@ const char *dnnl_prim_kind2str(dnnl_primitive_kind_t v) { if (v == dnnl_softmax) return "softmax"; if (v == dnnl_layer_normalization) return "layer_normalization"; if (v == dnnl_group_normalization) return "group_normalization"; + if (v == dnnl_depthwise) return "depthwise"; + if (v == dnnl_quantization) return "quantization"; if (v == dnnl_primitive_kind_max) return "primitive_kind_max"; if (v == dnnl::impl::primitive_kind::sdpa) return "sdpa"; assert(!"unknown prim_kind"); @@ -1830,6 +1832,10 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) { if (v == dnnl_reduction_norm_lp_power_p_sum) return "reduction_norm_lp_power_p_sum"; if (v == dnnl_softmax_accurate) return "softmax_accurate"; if (v == dnnl_softmax_log) return "softmax_log"; + if (v == dnnl_depthwise_scale_shift) return "depthwise_scale_shift"; + if (v == dnnl_depthwise_prelu) return "depthwise_prelu"; + if (v == dnnl_quantization_quantize_dequantize) return "quantization_quantize_dequantize"; + if (v == dnnl_quantization_quantize) return "quantization_quantize"; assert(!"unknown alg_kind"); return "unknown alg_kind"; } diff --git a/src/common/ittnotify.cpp b/src/common/ittnotify.cpp index e9c9dfa8404..65a65cb6e41 100644 --- a/src/common/ittnotify.cpp +++ b/src/common/ittnotify.cpp @@ -80,6 +80,8 @@ void primitive_task_start(primitive_kind_t kind) { CASE(layer_normalization), CASE(group_normalization), CASE(sdpa), + CASE(depthwise), + CASE(quantization), }; #undef CASE int kind_idx = (int)kind; diff --git a/src/common/math_utils.hpp b/src/common/math_utils.hpp index 0c156dff8db..993fe7259e4 100644 --- a/src/common/math_utils.hpp +++ b/src/common/math_utils.hpp @@ -567,6 +567,43 @@ inline float stochastic_round_fwd( return r; } +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) { + if (!bias) return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits
::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::bf16); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +inline float get_sum(char *sum, size_t offset, data_type_t data_type) +{ + if (!sum) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits
::type *)sum)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + } // namespace math } // namespace impl } // namespace dnnl diff --git a/src/common/nstl.hpp b/src/common/nstl.hpp index 45a6d7c49ac..f6887ef38d0 100644 --- a/src/common/nstl.hpp +++ b/src/common/nstl.hpp @@ -339,6 +339,10 @@ class vector : public c_compatible { } void clear() { _impl.clear(); } void push_back(const T &t) { _impl.push_back(t); } + template + void emplace_back(Args&&... args) { + _impl.emplace_back(std::forward(args)...); + } void resize(size_type count) { _impl.resize(count); } void reserve(size_type count) { _impl.reserve(count); } }; diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp index 09007dd968a..75126aded4a 100644 --- a/src/common/primitive_attr.cpp +++ b/src/common/primitive_attr.cpp @@ -73,6 +73,29 @@ status_t scales_t::set(dim_t count, int mask, const float *scales) { return status::success; } + +template +status_t shifts_t::set(int count, int mask, const T *shifts) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + shifts_ = shifts_buf_; + utils::array_set(shifts_, shifts[0], shifts_buf_size); + } else { + shifts_ = (T *)impl::malloc(count_ * sizeof(*shifts_), 64); + if (shifts_ == nullptr) + return status::out_of_memory; + + for (int c = 0; c < count_; ++c) + shifts_[c] = shifts[c]; + } + + return status::success; +} + status_t zero_points_t::get(int arg, int *mask, data_type_t *dt) const { if (mask) *mask = get_mask(arg); if (dt) *dt = get_data_type(arg); @@ -182,6 +205,14 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask, #undef CHECK_ARG } +bool primitive_attr_t::has_asymmetric_quantization() const { + return true + && output_scales_.has_default_values() + && rnn_data_qparams_.has_default_values() + && rnn_weights_qparams_.has_default_values() + && (!input_zero_points_.has_default_values() || !weights_zero_points_.has_default_values()); +} + bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const { using smask_t = skip_mask_t; bool ok = true; @@ -313,6 +344,47 @@ status_t post_ops_t::append_prelu(int mask) { return success; } +status_t post_ops_t::append_depthwise(alg_kind_t alg, const float* weights_data, const float* biases_data) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, depthwise_scale_shift, depthwise_prelu); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::depthwise; + e.depthwise.alg = alg; + e.depthwise.weights_data = weights_data; + e.depthwise.biases_data = biases_data; + + return success; +} + +status_t post_ops_t::append_quantization(alg_kind_t alg, + const void* crop_low, const void* crop_high, + const void* input_scale, const void* input_shift, + const void* output_scale, const void* output_shift) { + using namespace dnnl::impl::alg_kind; + if (len() == post_ops_limit) return out_of_memory; + bool known_alg = one_of(alg, quantization_quantize_dequantize, quantization_quantize); + if (!known_alg) + return invalid_arguments; + + entry_.emplace_back(); + auto &e = entry_.back(); + e.kind = primitive_kind::quantization; + e.quantization.alg = alg; + e.quantization.crop_low_data = reinterpret_cast*>(crop_low); + e.quantization.crop_high_data = reinterpret_cast*>(crop_high); + e.quantization.input_scale_data = reinterpret_cast(input_scale); + e.quantization.input_shift_data = reinterpret_cast*>(input_shift); + e.quantization.output_scale_data = reinterpret_cast(output_scale); + e.quantization.output_shift_data = reinterpret_cast*>(output_shift); + + return success; +} + bool post_ops_t::defined() const { for (int idx = 0; idx < len(); ++idx) { auto kind = entry_[idx].kind; @@ -327,6 +399,10 @@ bool post_ops_t::defined() const { primitive_kind::prelu, primitive_kind::convolution)) { // binary is always defined + } else if (kind == primitive_kind::depthwise) { + // depthwise is always defined + } else if (kind == primitive_kind::quantization) { + // quantization is always defined } else { assert(!"unreachable"); } @@ -787,6 +863,23 @@ status_t dnnl_post_ops_get_params_prelu( return success; } +status_t dnnl_post_ops_append_depthwise(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, + const float* weights_data, const float* biases_data) { + if (post_ops == nullptr) return invalid_arguments; + + return post_ops->append_depthwise(alg, weights_data, biases_data); +} + +status_t dnnl_post_ops_append_quantization(post_ops_t *post_ops, alg_kind_t kind, + const void* crop_low, const void* crop_high, + const void* input_scale, const void* input_shift, + const void* output_scale, const void* output_shift) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_quantization(kind, crop_low, crop_high, input_scale, input_shift, output_scale, output_shift); +} + status_t dnnl_primitive_attr_set_rnn_data_qparams( primitive_attr_t *attr, const float scale, const float shift) { if (attr == nullptr) return invalid_arguments; @@ -854,3 +947,7 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams( return attr->rnn_tparams_.set(mode, ngates, scales, cscale); } + +template struct dnnl::impl::shifts_t; +template struct dnnl::impl::shifts_t; +template struct dnnl::impl::shifts_t; diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index 5e1496978ed..36e54f35c73 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -185,6 +185,53 @@ struct scales_t : public c_compatible { DNNL_DISALLOW_COPY_AND_ASSIGN(scales_t); }; +template +struct shifts_t: public c_compatible { + shifts_t(): count_(1), mask_(0), shifts_(shifts_buf_) + { set(0); } + + shifts_t(const shifts_t &rhs): shifts_t() + { set(rhs.count_, rhs.mask_, rhs.shifts_); } + + ~shifts_t() { cleanup(); } + + shifts_t &operator=(const shifts_t &rhs) { + if (&rhs == this) + return *this; + status_t status = set(rhs.count_, rhs.mask_, rhs.shifts_); + assert(status == status::success); + (void)status; + return *this; + } + + bool has_default_values() const { + for (int c = 0; c < count_; ++c) { + if(shifts_[c] != 0) return false; + } + return true; + } + + status_t set(int count, int mask, const T *zero_points); + status_t set(T single_zero_point) { return this->set(1, 0, &single_zero_point); } + + int count_; + int mask_; + T *shifts_; + +private: + enum { shifts_buf_size = 16 }; + T shifts_buf_[shifts_buf_size]; + + void cleanup() { + if (shifts_ != shifts_buf_ && shifts_ != nullptr) + impl::free(shifts_); + + count_ = 1; + mask_ = 0; + shifts_ = shifts_buf_; + } +}; + struct runtime_scales_t : public c_compatible { // Clang-3.8.1 raises an error for a default initialization of a const // object. Const runtime_scales_t object is used as default_scales. @@ -609,6 +656,22 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { int mask; }; + struct depthwise_t { + dnnl::impl::alg_kind_t alg; + const float* weights_data; + const float* biases_data; + }; + + struct quantization_t { + dnnl::impl::alg_kind_t alg; + const dnnl::impl::shifts_t* crop_low_data; + const dnnl::impl::shifts_t* crop_high_data; + const dnnl::impl::scales_t* input_scale_data; + const dnnl::impl::shifts_t* input_shift_data; + const dnnl::impl::scales_t* output_scale_data; + const dnnl::impl::shifts_t* output_shift_data; + }; + dnnl::impl::primitive_kind_t kind = dnnl::impl::primitive_kind::undefined; union { @@ -617,6 +680,8 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { depthwise_conv_t depthwise_conv; binary_t binary; prelu_t prelu; + depthwise_t depthwise; + quantization_t quantization; }; bool is_eltwise(bool require_scale_one = false) const { @@ -655,6 +720,15 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { } bool is_like_binary() const { return is_binary() || is_prelu(); } + bool is_depthwise() const { + using namespace dnnl::impl; + return kind == primitive_kind::depthwise; + } + + bool is_quantization() const { + using namespace dnnl::impl; + return kind == primitive_kind::quantization; + } dnnl::impl::status_t set_depthwise_scales(const float *scales); @@ -697,6 +771,20 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { case primitive_kind::prelu: ret = prelu.mask == rhs.prelu.mask; break; + case primitive_kind::depthwise: + ret = depthwise.alg == rhs.depthwise.alg + && depthwise.weights_data == rhs.depthwise.weights_data + && depthwise.biases_data == rhs.depthwise.biases_data; + break; + case primitive_kind::quantization: + ret = quantization.alg == rhs.quantization.alg + && quantization.crop_low_data == rhs.quantization.crop_low_data + && quantization.crop_high_data == rhs.quantization.crop_high_data + && quantization.input_scale_data == rhs.quantization.input_scale_data + && quantization.input_shift_data == rhs.quantization.input_shift_data + && quantization.output_scale_data == rhs.quantization.output_scale_data + && quantization.output_shift_data == rhs.quantization.output_shift_data; + break; default: assert(!"unsupported post_op"); } return ret; @@ -721,6 +809,12 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { dnnl::impl::status_t append_binary(dnnl::impl::alg_kind_t alg, const dnnl::impl::memory_desc_t *user_src1_desc); dnnl::impl::status_t append_prelu(int mask); + dnnl::impl::status_t append_depthwise(dnnl::impl::alg_kind_t alg, + const float* weights_data, const float* biases_data); + dnnl::impl::status_t append_quantization(dnnl::impl::alg_kind_t alg, + const void* crop_low, const void* crop_high, + const void* input_scale, const void* input_shift, + const void* output_scale, const void* output_shift); dnnl::impl::status_t prepend_binary(dnnl::impl::alg_kind_t alg, const dnnl::impl::memory_desc_t *user_src1_desc); @@ -743,6 +837,16 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible { return dst_dt; } + int count(dnnl::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len(); + stop = dnnl::impl::nstl::min(stop, len()); + int cnt = 0; + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) cnt++; + return cnt; + } + bool defined() const; int len() const { return (int)entry_.size(); } bool has_default_values( @@ -882,6 +986,8 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { bool has_default_values(skip_mask_t mask = skip_mask_t::none, dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) const; + bool has_asymmetric_quantization() const; + /** Returns true if the attributes are fully defined. */ bool defined(skip_mask_t mask = skip_mask_t::none) const; @@ -980,6 +1086,9 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { dnnl::impl::rnd_mode_t rounding_mode_; std::unique_ptr gpu_attr_; + dnnl::impl::shifts_t input_zero_points_; + dnnl::impl::shifts_t weights_zero_points_; + dnnl::impl::shifts_t output_compensations_; dnnl_primitive_attr &operator=(const dnnl_primitive_attr &other) = delete; }; diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp index eb3f33d2b86..53cc548ae6f 100644 --- a/src/common/primitive_hashing.cpp +++ b/src/common/primitive_hashing.cpp @@ -305,6 +305,20 @@ size_t get_attr_hash(const primitive_attr_t &attr) { seed = hash_combine( seed, static_cast(entry.prelu.mask)); break; + case primitive_kind::depthwise: + seed = hash_combine(seed, static_cast(entry.depthwise.alg)); + seed = hash_combine(seed, reinterpret_cast(entry.depthwise.weights_data)); + seed = hash_combine(seed, reinterpret_cast(entry.depthwise.biases_data)); + break; + case primitive_kind::quantization: + seed = hash_combine(seed, static_cast(entry.quantization.alg)); + seed = hash_combine(seed, reinterpret_cast(entry.quantization.crop_high_data)); + seed = hash_combine(seed, reinterpret_cast(entry.quantization.crop_low_data)); + seed = hash_combine(seed, reinterpret_cast(entry.quantization.input_scale_data)); + seed = hash_combine(seed, reinterpret_cast(entry.quantization.input_shift_data)); + seed = hash_combine(seed, reinterpret_cast(entry.quantization.output_scale_data)); + seed = hash_combine(seed, reinterpret_cast(entry.quantization.output_shift_data)); + break; default: assert(!"unknown post_op"); } } diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index e34b633add4..ce1a8d58e4c 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -814,6 +814,14 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { ss << delim << "prelu" << ":" << ep.mask; } break; + case primitive_kind::depthwise: { + const post_ops_t::entry_t::depthwise_t &dw = e.depthwise; + ss << delim << dw.alg; + } break; + case primitive_kind::quantization: { + const post_ops_t::entry_t::quantization_t &qt = e.quantization; + ss << delim << qt.alg; + } break; default: assert(!"unsupported post op primitive kind!"); break; } delim = attr_delim; diff --git a/src/cpu/cpu_pooling_list.cpp b/src/cpu/cpu_pooling_list.cpp index 951395c44bc..f9619ebde61 100644 --- a/src/cpu/cpu_pooling_list.cpp +++ b/src/cpu/cpu_pooling_list.cpp @@ -75,21 +75,19 @@ const std::map> &impl_list_map() { CPU_INSTANCE(nhwc_pooling_fwd_t) CPU_INSTANCE(nhwc_pooling_fwd_t) CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(nhwc_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) /* int */ CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t) CPU_INSTANCE_AARCH64(jit_uni_i8i8_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) - CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) + CPU_INSTANCE(ref_pooling_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ diff --git a/src/cpu/gemm_convolution_utils.hpp b/src/cpu/gemm_convolution_utils.hpp index 43e9784bc44..6d7d4523ad4 100644 --- a/src/cpu/gemm_convolution_utils.hpp +++ b/src/cpu/gemm_convolution_utils.hpp @@ -84,6 +84,27 @@ struct single_gemm_conv_chunk_desc_t { dim_t w_size_ = 0; }; +namespace gemm_convolution_utils { + +struct pp_kernel_t { + static pp_kernel_t *create( + const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + + virtual ~pp_kernel_t() = default; + + virtual void operator()(float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride) const = 0; + + virtual status_t create_kernel() { return status::success; } + +protected: + pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp); + + bool do_bias_ = false; + post_ops_t post_ops_; +}; + +} // namespace gemm_convolution_utils + namespace jit_gemm_convolution_utils { template void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im, diff --git a/src/cpu/ref_pooling.cpp b/src/cpu/ref_pooling.cpp index 00dcb566860..280cf360cde 100644 --- a/src/cpu/ref_pooling.cpp +++ b/src/cpu/ref_pooling.cpp @@ -43,13 +43,13 @@ static inline dim_t get_offset(const memory_desc_wrapper &mdw, dim_t n, dim_t c, using namespace nstl; -template -status_t ref_pooling_fwd_t::execute_forward( +template +status_t ref_pooling_fwd_t::execute_forward( const exec_ctx_t &ctx) const { status_t status = status::success; - auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); - auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); + auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); + auto dst = CTX_OUT_CLEAN_MEM(dst_data_t *, DNNL_ARG_DST, status); CHECK(status); auto ws = CTX_OUT_CLEAN_MEM(unsigned char *, DNNL_ARG_WORKSPACE, status); CHECK(status); @@ -172,7 +172,7 @@ status_t ref_pooling_fwd_t::execute_forward( const bool is_max_pool = alg == alg_kind::pooling_max; float base_res - = is_max_pool ? (float)numeric_limits::lowest() : 0.f; + = is_max_pool ? (float)numeric_limits::lowest() : 0.f; using ker_t = std::function; ker_t kernel = is_max_pool ? (ker_t)ker_max : (ker_t)ker_avg; @@ -191,7 +191,7 @@ status_t ref_pooling_fwd_t::execute_forward( args.dst_md = pd()->dst_md(); ref_post_ops->execute(res, args); - dst[data_p_off] = cpu::q10n::saturate_and_round(res); + dst[data_p_off] = cpu::q10n::saturate_and_round(res); }); return status::success; @@ -371,14 +371,14 @@ status_t ref_pooling_bwd_t::execute(const exec_ctx_t &ctx) const { return status::success; } -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; } // namespace cpu } // namespace impl diff --git a/src/cpu/ref_pooling.hpp b/src/cpu/ref_pooling.hpp index d6e89f5b195..1327cc39a1b 100644 --- a/src/cpu/ref_pooling.hpp +++ b/src/cpu/ref_pooling.hpp @@ -33,7 +33,7 @@ namespace dnnl { namespace impl { namespace cpu { -template +template struct ref_pooling_fwd_t : public primitive_t { struct pd_t : public cpu_pooling_fwd_pd_t { using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; @@ -43,23 +43,29 @@ struct ref_pooling_fwd_t : public primitive_t { status_t init(engine_t *engine) { using sm = primitive_attr_t::skip_mask_t; - VDISPATCH_POOLING(platform::has_data_type_support(data_type), + VDISPATCH_POOLING(platform::has_data_type_support(src_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_POOLING(platform::has_data_type_support(dst_type), VERBOSE_UNSUPPORTED_DT); VDISPATCH_POOLING(set_default_params() == status::success, VERBOSE_UNSUPPORTED_TAG); VDISPATCH_POOLING(is_fwd(), VERBOSE_BAD_PROPKIND); - VDISPATCH_POOLING(utils::everyone_is(data_type, src_md()->data_type, - dst_md()->data_type), + VDISPATCH_POOLING(utils::everyone_is(src_type, src_md()->data_type), + VERBOSE_UNSUPPORTED_DT); + VDISPATCH_POOLING(utils::everyone_is(dst_type, dst_md()->data_type), VERBOSE_UNSUPPORTED_DT); VDISPATCH_POOLING(desc()->accum_data_type == acc_type, VERBOSE_UNSUPPORTED_DT); VDISPATCH_POOLING(attr()->has_default_values(sm::post_ops), VERBOSE_UNSUPPORTED_ATTR); + // VDISPATCH_POOLING( + // ref_post_ops_t::primitive_kind_ok(attr()->post_ops_), + // VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_POOLING( - ref_post_ops_t::primitive_kind_ok(attr()->post_ops_), + attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_POOLING( - attr_.set_default_formats(dst_md(0)) == status::success, + is_supported_post_ops(), VERBOSE_UNSUPPORTED_POSTOP); bool is_training = desc_.prop_kind == prop_kind::forward_training; @@ -68,6 +74,24 @@ struct ref_pooling_fwd_t : public primitive_t { return status::success; } + + virtual bool is_supported_post_ops() const { + const auto &p = this->attr()->post_ops_; + + auto all_post_ops_supported = [&]() { + bool ok = true; + + for (int i = 0; i < p.len(); i++) { + ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::quantization); + } + return ok; + }; + + return all_post_ops_supported() && + IMPLICATION(p.len() > 0, (desc()->alg_kind == dnnl_pooling_avg_include_padding || desc()->alg_kind == dnnl_pooling_avg_exclude_padding) && + src_type != data_type::bf16); + + } }; ref_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {} @@ -80,7 +104,8 @@ struct ref_pooling_fwd_t : public primitive_t { return status::success; } - using data_t = typename prec_traits::type; + using src_data_t = typename prec_traits::type; + using dst_data_t = typename prec_traits::type; using acc_data_t = typename prec_traits::type; status_t execute(const exec_ctx_t &ctx) const override { diff --git a/src/cpu/x64/gemm_bf16_convolution.hpp b/src/cpu/x64/gemm_bf16_convolution.hpp index 7fe15a2d36b..f13a264ac1c 100644 --- a/src/cpu/x64/gemm_bf16_convolution.hpp +++ b/src/cpu/x64/gemm_bf16_convolution.hpp @@ -182,7 +182,7 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t { Xbyak::Reg64 reserved_eltwise_gpr = r10; Xbyak::Opmask reserved_eltwise_maskr = k2; - Xbyak::Zmm vreg_sum_scale, vreg_bias; + Xbyak::Zmm vreg_sum_scale, vreg_bias, vreg_dw; Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27); Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(28); diff --git a/src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp b/src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp new file mode 100644 index 00000000000..b8ffde35bf4 --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_depthwise_injector.cpp @@ -0,0 +1,265 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" +#include "cpu/x64/injectors/injector_utils.hpp" + +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +int jit_uni_depthwise_injector_f32::aux_vecs_count(alg_kind_t depthwise_alg, bool is_broadcast) { + switch (depthwise_alg) { + case alg_kind::depthwise_scale_shift: return isa == sse41 || is_broadcast ? 1 : 0; + case alg_kind::depthwise_prelu: return 2; + default: assert(!"unsupported depthwise algorithm"); + } + + return 0; +} + +template +void jit_uni_depthwise_injector_f32::injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast) { + preserved_vecs_count = 0; + vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32::aux_vecs_count(depthwise_alg, is_broadcast); + + for (size_t i = 0; i < vecs_count; i++) { + if (preserved_vecs_count >= vecs_to_preserve) + break; + + if (i < start_idx || i >= end_idx) { + preserved_vec_idxs[preserved_vecs_count] = i; + preserved_vecs_count++; + } + } + + start_idx_tail = start_idx; + size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; + for (size_t i = 0; i < preserved_vecs_count_tail; i++) { + preserved_vec_idxs[preserved_vecs_count] = start_idx + i; + preserved_vecs_count++; + start_idx_tail = start_idx + i + 1; + } + + h->sub(h->rsp, preserved_vecs_count * vlen); + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i])); + + assign_regs(); +} + +template +void jit_uni_depthwise_injector_f32::injector_preamble_tail(size_t start_idx, size_t end_idx) { + size_t tail_vecs_to_preserve = start_idx_tail - start_idx; + int idx_off = (vecs_to_preserve - tail_vecs_to_preserve); + + if (tail_vecs_to_preserve > 0) { + h->add(h->rsp, idx_off * vlen); + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]); + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) { + preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; + } + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i])); + h->sub(h->rsp, idx_off * vlen); + + assign_regs(); + } +} + +template +void jit_uni_depthwise_injector_f32::injector_postamble() { + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]); + h->add(h->rsp, preserved_vecs_count * vlen); +} + +template +void jit_uni_depthwise_injector_f32::assign_regs() { + vmm_mask = Vmm(preserved_vec_idxs[0]); + vmm_aux0 = Vmm(preserved_vec_idxs[1]); +} + +template +void jit_uni_depthwise_injector_f32::scale_shift_compute_vector(const Vmm &vmm_src, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast, int offset) { + if (isa == sse41) { + if (is_broadcast) + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights]); + else + h->movups(vmm_mask, h->ptr[p_weights + offset]); + h->mulps(vmm_src, vmm_mask); + if (is_broadcast) + h->uni_vbroadcastss(vmm_mask, h->ptr[p_bias]); + else + h->movups(vmm_mask, h->ptr[p_bias + offset]); + h->addps(vmm_src, vmm_mask); + } else { + if (is_broadcast) { + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights]); + h->uni_vmulps(vmm_src, vmm_src, vmm_mask); + h->uni_vbroadcastss(vmm_mask, h->ptr[p_bias]); + h->uni_vaddps(vmm_src, vmm_src, vmm_mask); + } else { + h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights + offset]); + h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias+ offset]); + } + }; +} + +template +void jit_uni_depthwise_injector_f32::prelu_compute_vector(const Vmm &vmm_src, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast, int offset) { + const unsigned char _cmp_gt_os = 6; + const unsigned char _cmp_lt_os = 1; + + if (isa == sse41) { + h->pxor(vmm_mask, vmm_mask); + h->cmpps(vmm_mask, vmm_src, _cmp_gt_os); + if (is_broadcast) + h->uni_vbroadcastss(vmm_aux0, h->ptr[p_weights]); + else + h->movups(vmm_aux0, h->ptr[p_weights + offset]); + h->mulps(vmm_aux0, vmm_src); + h->blendvps(vmm_src, vmm_aux0); + } else if (isa == avx2) { + if (is_broadcast) { + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights]); + h->vmulps(vmm_aux0, vmm_src, vmm_mask); + } else + h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights + offset]); + h->vxorps(vmm_mask, vmm_mask, vmm_mask); + h->vcmpgtps(vmm_mask, vmm_src, vmm_mask); + h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask); + } else if (isa == avx512_core) { + h->vxorpd(vmm_mask, vmm_mask, vmm_mask); + h->vmovups(vmm_aux0, vmm_src); + h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os); + if (is_broadcast) { + h->uni_vbroadcastss(vmm_mask, h->ptr[p_weights]); + h->vmulps(vmm_src | k_mask, vmm_aux0, vmm_mask); + } else + h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights + offset]); + } +} + +template +void jit_uni_depthwise_injector_f32::compute_body(size_t start_idx, size_t end_idx, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast) { + for (size_t idx = start_idx; idx < end_idx; idx++) { + switch (depthwise_alg) { + case alg_kind::depthwise_scale_shift: + scale_shift_compute_vector(Vmm(idx), p_weights, p_bias, is_broadcast); break; + case alg_kind::depthwise_prelu: + prelu_compute_vector(Vmm(idx), p_weights, p_bias, is_broadcast); break; + default: assert(!"unsupported depthwise algorithm"); + } + } +} + +template +void jit_uni_depthwise_injector_f32::compute_vector_range(int start_idx, int end_idx, + const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast) { + injector_preamble(start_idx, end_idx, is_broadcast); + compute_body(start_idx_tail, end_idx, p_weights, p_bias, is_broadcast); + injector_preamble_tail(start_idx, end_idx); + compute_body(start_idx, start_idx_tail, p_weights, p_bias, is_broadcast); + injector_postamble(); +} + + +template +void jit_uni_depthwise_injector_f32::init_ptrs(const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias, + const Xbyak::Operand& ch_off, bool is_broadcast) { + h->mov(reg_d_weights, reinterpret_cast(post_op_.depthwise.weights_data)); + if (post_op_.depthwise.alg == alg_kind::depthwise_scale_shift) + h->mov(reg_d_bias, reinterpret_cast(post_op_.depthwise.biases_data)); + + if (!is_broadcast) { + h->add(reg_d_weights, ch_off); + if (post_op_.depthwise.alg == alg_kind::depthwise_scale_shift) + h->add(reg_d_bias, ch_off); + } +} + +template +static void push_vmm(jit_generator *host, const Vmm &vmm) { + host->sub(host->rsp, vreg_traits::vlen); + host->uni_vmovups(host->ptr[host->rsp], vmm); +} + +template +static void pop_vmm(jit_generator *host, const Vmm &vmm) { + host->uni_vmovups(vmm, host->ptr[host->rsp]); + host->add(host->rsp, vreg_traits::vlen); +} + +template +void jit_uni_depthwise_injector_f32::compute(int start_idx, int end_idx, + int vmm_d_weights_idx, int vmm_d_bias_idx, + const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias, + bool is_broadcast, int offset, bool need_to_preserve) { + vmm_mask = Vmm(vmm_d_weights_idx); + vmm_aux0 = Vmm(vmm_d_bias_idx); + + if (need_to_preserve) { + preserved_vecs_count = aux_vecs_count(depthwise_alg, is_broadcast); + if (preserved_vecs_count > 0) + push_vmm(h, vmm_mask); + if (preserved_vecs_count > 1) + push_vmm(h, vmm_aux0); + } + + for (int idx = start_idx; idx < end_idx; idx++) { + switch (depthwise_alg) { + case alg_kind::depthwise_scale_shift: + scale_shift_compute_vector(Vmm(idx), reg_d_weights, reg_d_bias, is_broadcast, offset); break; + case alg_kind::depthwise_prelu: + prelu_compute_vector(Vmm(idx), reg_d_weights, reg_d_bias, is_broadcast, offset); break; + default: assert(!"unsupported depthwise algorithm"); + } + } + + if (need_to_preserve) { + if (preserved_vecs_count > 1) + pop_vmm(h, vmm_aux0); + if (preserved_vecs_count > 1) + pop_vmm(h, vmm_mask); + } +} + +template struct jit_uni_depthwise_injector_f32; +template struct jit_uni_depthwise_injector_f32; +template struct jit_uni_depthwise_injector_f32; +template struct jit_uni_depthwise_injector_f32; +template struct jit_uni_depthwise_injector_f32; +template struct jit_uni_depthwise_injector_f32; +template struct jit_uni_depthwise_injector_f32; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp b/src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp new file mode 100644 index 00000000000..3529c4f64da --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_depthwise_injector.hpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_DEPTHWISE_INJECTOR_HPP +#define CPU_X64_JIT_UNI_DEPTHWISE_INJECTOR_HPP + +#include + +#include "../../../common/c_types_map.hpp" +#include "../../../common/primitive_attr.hpp" +#include "../../../common/type_helpers.hpp" +#include "../../../common/utils.hpp" + +#include "../jit_generator.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +namespace depthwise_injector { + +struct static_params_t { + static_params_t(int vmm_d_weights_idx = 0, int vmm_d_bias_idx = 0, + Xbyak::Reg64 reg_d_weights = Xbyak::Reg64(0), Xbyak::Reg64 reg_d_bias = Xbyak::Reg64(0)) : + vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias) {} + + int vmm_d_weights_idx; + int vmm_d_bias_idx; + Xbyak::Reg64 reg_d_weights; + Xbyak::Reg64 reg_d_bias; +}; + +struct dynamic_params_t { + dynamic_params_t(int vmm_d_weights_idx = 0, int vmm_d_bias_idx = 0, + Xbyak::Reg64 reg_d_weights = Xbyak::Reg64(0), Xbyak::Reg64 reg_d_bias = Xbyak::Reg64(0), + Xbyak::Reg64 reg_init_off = Xbyak::Reg64(0), const std::map vmm_idx_off = {}) : + vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias), + reg_init_off(reg_init_off), reg_init_off_addr(0), vmm_idx_off(vmm_idx_off), useAddr(false) {} + + dynamic_params_t(int vmm_d_weights_idx, int vmm_d_bias_idx, + Xbyak::Reg64 reg_d_weights, Xbyak::Reg64 reg_d_bias, + Xbyak::Address reg_init_off, const std::map vmm_idx_off) : + vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias), + reg_init_off(0), reg_init_off_addr(reg_init_off), vmm_idx_off(vmm_idx_off), useAddr(true) {} + + int vmm_d_weights_idx; + int vmm_d_bias_idx; + Xbyak::Reg64 reg_d_weights; + Xbyak::Reg64 reg_d_bias; + Xbyak::Reg64 reg_init_off; + Xbyak::Address reg_init_off_addr; + std::map vmm_idx_off; + bool useAddr; +}; + +} // quantization_injector + +template +struct jit_uni_depthwise_injector_f32 { + using Vmm = typename utils::conditional3::type; + + jit_uni_depthwise_injector_f32(jit_generator* host, alg_kind_t depthwise_alg_, Xbyak::Opmask k_mask_ = Xbyak::Opmask(1)) + : h(host), depthwise_alg(depthwise_alg_), k_mask(k_mask_) { + assert(utils::one_of(depthwise_alg, alg_kind::depthwise_scale_shift, alg_kind::depthwise_prelu)); + } + + jit_uni_depthwise_injector_f32(jit_generator* host, dnnl_post_ops::entry_t post_op, Xbyak::Opmask k_mask_ = Xbyak::Opmask(1)) + : h(host), post_op_(post_op), k_mask(k_mask_) { + depthwise_alg = post_op.depthwise.alg; + assert(utils::one_of(depthwise_alg, alg_kind::depthwise_scale_shift, alg_kind::depthwise_prelu)); + } + + void compute_vector_range(int start_idx, int end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false); + + void init_ptrs(const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias, + const Xbyak::Operand& ch_off, bool is_broadcast); + void compute(int start_idx, int end_idx, + int vmm_d_weights_idx, int vmm_d_bias_idx, + const Xbyak::Reg64& reg_d_weights, const Xbyak::Reg64& reg_d_bias, + bool is_broadcast = false, int offset = 0, bool need_to_preserve = false); + +private: + jit_generator* h; + + size_t vlen = cpu_isa_traits::vlen; + + alg_kind_t depthwise_alg; + + mutable Vmm vmm_mask; + mutable Vmm vmm_aux0; + + dnnl_post_ops::entry_t post_op_; + + Xbyak::Opmask k_mask; + + const static size_t preserved_vecs_max = 5; + size_t vecs_to_preserve = 0; + size_t vecs_count = isa == avx512_core ? 32 : 16; + size_t preserved_vecs_count = 0; + size_t preserved_vec_idxs[preserved_vecs_max] = {0}; + size_t start_idx_tail = 0; + + int aux_vecs_count(alg_kind_t elt_alg, bool is_broadcast); + + void compute_body(size_t start_idx, size_t end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false); + void injector_preamble(size_t start_idx, size_t end_idx, bool is_broadcast = false); + void injector_preamble_tail(size_t start_idx, size_t end_idx); + void injector_postamble(); + void assign_regs(); + + void scale_shift_compute_vector(const Vmm &vmm_src, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false, int offset = 0); + void prelu_compute_vector(const Vmm &vmm_src, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias, bool is_broadcast = false, int offset = 0); +}; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/x64/injectors/jit_uni_postops_injector.cpp b/src/cpu/x64/injectors/jit_uni_postops_injector.cpp index dfa7ca4f533..c71fc388fe4 100644 --- a/src/cpu/x64/injectors/jit_uni_postops_injector.cpp +++ b/src/cpu/x64/injectors/jit_uni_postops_injector.cpp @@ -53,6 +53,7 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params, const lambda_jit_injectors_t &lambda_jit_injectors) : post_ops_(post_ops) , host_(host) @@ -60,6 +61,7 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( , lambda_jit_injectors_(lambda_jit_injectors) { const auto &esp = eltwise_static_params; + const auto &qsp = quantization_static_params; bool is_like_binary = false; bool is_eltwise = false; @@ -78,6 +80,17 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( esp.preserve_vmm, esp.preserve_p_table)); } else if (post_op.is_like_binary()) { is_like_binary = true; + } else if (post_op.is_depthwise()) { + depthwise_injectors.emplace_back(new jit_uni_depthwise_injector_f32( + host, + post_op + )); + } else if (post_op.is_quantization()) { + quantization_injectors.emplace_back(new jit_uni_quantization_injector_f32( + host, + post_op, + Vmm(qsp.vmm_d_weights_idx), Vmm(qsp.vmm_d_bias_idx), qsp.reg_d_weights, qsp.reg_d_bias + )); } } @@ -100,7 +113,8 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params) : jit_uni_postops_injector_t(host, post_ops, binary_static_params, - eltwise_injector::static_params_t(), lambda_jit_injectors_t()) {} + eltwise_injector::static_params_t(), quantization_injector::static_params_t(), + lambda_jit_injectors_t()) {} template jit_uni_postops_injector_t::jit_uni_postops_injector_t( @@ -108,7 +122,8 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( const binary_injector::static_params_t &binary_static_params, const lambda_jit_injectors_t &lambda_jit_injectors) : jit_uni_postops_injector_t(host, post_ops, binary_static_params, - eltwise_injector::static_params_t(), lambda_jit_injectors) {} + eltwise_injector::static_params_t(), quantization_injector::static_params_t(), + lambda_jit_injectors) {} template jit_uni_postops_injector_t::jit_uni_postops_injector_t( @@ -116,7 +131,17 @@ jit_uni_postops_injector_t::jit_uni_postops_injector_t( const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params) : jit_uni_postops_injector_t(host, post_ops, binary_static_params, - eltwise_static_params, lambda_jit_injectors_t()) {} + eltwise_static_params, + quantization_injector::static_params_t(), lambda_jit_injectors_t()) {} + +template +jit_uni_postops_injector_t::jit_uni_postops_injector_t(jit_generator *host, + const post_ops_t &post_ops, + const binary_injector::static_params_t &binary_static_params, + const quantization_injector::static_params_t &quantization_static_params) + : jit_uni_postops_injector_t(host, post_ops, binary_static_params, + eltwise_injector::static_params_t(), + quantization_static_params, lambda_jit_injectors_t()) {} // Specialization instantiations are needed to avoid instantiating ISA with // Vmm that don't make any sense like sse41 + Zmm. @@ -259,6 +284,19 @@ void jit_uni_postops_injector_t::compute_vector_range( compute_vector_range(vmm_idxs, rhs_arg_params); } +template +void jit_uni_postops_injector_t::compute_vector_range( + size_t start_idx, size_t end_idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) { + + injector_utils::vmm_index_set_t vmm_idxs; + for (size_t i = start_idx; i < end_idx; i++) + vmm_idxs.emplace(i); + compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); +} + template void jit_uni_postops_injector_t::compute_vector_range( size_t start_idx, size_t end_idx) { @@ -269,17 +307,69 @@ void jit_uni_postops_injector_t::compute_vector_range( template void jit_uni_postops_injector_t::compute_vector_range( const injector_utils::vmm_index_set_t &vmm_idxs, - const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) { std::size_t rhs_arg_idx = 0; + std::size_t quantization_inj_idx = 0; + std::size_t depthwise_inj_idx = 0; for (int i = 0; i < post_ops_.len(); i++) { const auto &post_op = post_ops_.entry_[i]; + if (post_op.is_eltwise()) { alg_to_eltwise_injector_.at(i).compute_vector_range(vmm_idxs); } else if (post_op.is_like_binary()) { binary_injector_->compute_vector_range( vmm_idxs, rhs_arg_idx, post_op, rhs_arg_params); ++rhs_arg_idx; + } else if (post_op.is_depthwise()) { + if (ddp.useAddr) + depthwise_injectors[depthwise_inj_idx]->init_ptrs(ddp.reg_d_weights, ddp.reg_d_bias, ddp.reg_init_off_addr, false); + else + depthwise_injectors[depthwise_inj_idx]->init_ptrs(ddp.reg_d_weights, ddp.reg_d_bias, ddp.reg_init_off, false); + + bool need_to_preserve = false; + if (post_op.depthwise.alg == dnnl_depthwise_prelu && isa == sse41) + need_to_preserve = true; + + for (auto vmm_idx : vmm_idxs) { + depthwise_injectors[depthwise_inj_idx]->compute(vmm_idx, vmm_idx + 1, + need_to_preserve ? 0 : ddp.vmm_d_weights_idx, ddp.vmm_d_bias_idx, + ddp.reg_d_weights, ddp.reg_d_bias, + false, ddp.vmm_idx_off.at(vmm_idx), need_to_preserve); + } + + depthwise_inj_idx++; + } else if (post_op.is_quantization()) { + bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize; + bool do_rounding = do_dequantization || qdp.dst_dt == dnnl_f32 || i != post_ops_.len() - 1; + + if (qdp.useAddr) + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(qdp.reg_oc_off_addr); + else + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(qdp.reg_oc_off); + for (auto vmm_idx : vmm_idxs) { + quantization_injectors[quantization_inj_idx]->compute_crop(vmm_idx, vmm_idx + 1, qdp.vmm_idx_off.at(vmm_idx)); + } + + if (qdp.useAddr) + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(qdp.reg_oc_off_addr); + else + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(qdp.reg_oc_off); + for (auto vmm_idx : vmm_idxs) { + quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(vmm_idx, vmm_idx + 1, qdp.vmm_idx_off.at(vmm_idx), do_rounding); + } + + if (qdp.useAddr) + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(qdp.reg_oc_off_addr); + else + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(qdp.reg_oc_off); + for (auto vmm_idx : vmm_idxs) { + quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(vmm_idx, vmm_idx + 1, qdp.vmm_idx_off.at(vmm_idx)); + } + + quantization_inj_idx++; } else { const auto lam = lambda_jit_injectors_.find(post_op.kind); if (lam != lambda_jit_injectors_.end()) lam->second(); @@ -292,6 +382,13 @@ void jit_uni_postops_injector_t::compute_vector_range( compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t()); } +template +void jit_uni_postops_injector_t::compute_vector_range( + const injector_utils::vmm_index_set_t &vmm_idxs, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { + compute_vector_range(vmm_idxs, rhs_arg_params, depthwise_injector::dynamic_params_t(), quantization_injector::dynamic_params_t()); +} + template void jit_uni_postops_injector_t::prepare_table(bool gen_table) { for (auto &alg_elt_inject : alg_to_eltwise_injector_) @@ -394,6 +491,8 @@ bool post_ops_ok(const post_ops_ok_args_t &post_ops_ok_args) { *dst_d, enabled_bcast_strategy); } break; + case depthwise: if (entry.is_depthwise()) return true; break; + case quantization: if (entry.is_quantization()) return true; break; default: assert(false && "Unhandled post_op type"); } } diff --git a/src/cpu/x64/injectors/jit_uni_postops_injector.hpp b/src/cpu/x64/injectors/jit_uni_postops_injector.hpp index eee12c7d7b3..127085735ba 100644 --- a/src/cpu/x64/injectors/jit_uni_postops_injector.hpp +++ b/src/cpu/x64/injectors/jit_uni_postops_injector.hpp @@ -27,6 +27,8 @@ #include "cpu/x64/injectors/injector_utils.hpp" #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp" #include "cpu/x64/jit_generator.hpp" #include @@ -95,6 +97,17 @@ class jit_uni_postops_injector_base_t { = 0; virtual void compute_vector_range(size_t start_idx, size_t end_idx) = 0; + + virtual void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) = 0; + + virtual void compute_vector_range( + size_t start_idx, size_t end_idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) = 0; // Generates code of post_ops chain injected to host primitive. Applied to // a single vector register index. // @rhs_arg_params: see jit_uni_binary_injector description @@ -134,13 +147,28 @@ class jit_uni_postops_injector_t : public jit_uni_postops_injector_base_t { jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params); + jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, + const binary_injector::static_params_t &binary_static_params, + const quantization_injector::static_params_t &quantization_static_params); jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, const binary_injector::static_params_t &binary_static_params, const eltwise_injector::static_params_t &eltwise_static_params, + const quantization_injector::static_params_t &quantization_static_params, const lambda_jit_injectors_t &lambda_jit_injectors); virtual ~jit_uni_postops_injector_t() = default; + void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) override; + + void compute_vector_range( + size_t start_idx, size_t end_idx, + const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params, + const depthwise_injector::dynamic_params_t &ddp, + const quantization_injector::dynamic_params_t &qdp) override; + // See `jit_uni_postops_injector_base_t::compute_vector_range(...)` void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs, const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) @@ -177,9 +205,11 @@ class jit_uni_postops_injector_t : public jit_uni_postops_injector_base_t { std::unique_ptr> binary_injector_; lambda_jit_injectors_t lambda_jit_injectors_; + nstl::vector>> depthwise_injectors; + nstl::vector>> quantization_injectors; }; -enum post_op_type { sum = 0, eltwise, binary, prelu }; +enum post_op_type { sum = 0, eltwise, binary, prelu, depthwise, quantization }; struct post_ops_ok_args_t { post_ops_ok_args_t(const cpu_isa_t isa, diff --git a/src/cpu/x64/injectors/jit_uni_quantization_injector.cpp b/src/cpu/x64/injectors/jit_uni_quantization_injector.cpp new file mode 100644 index 00000000000..2bcdb697257 --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_quantization_injector.cpp @@ -0,0 +1,248 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +template +void jit_uni_quantization_injector_f32::init_crop_ptrs(const Xbyak::Operand& ch_off) { + h->mov(reg_d_weights_, reinterpret_cast(post_op_.quantization.crop_low_data->shifts_)); + h->mov(reg_d_bias_, reinterpret_cast(post_op_.quantization.crop_high_data->shifts_)); + + if (post_op_.quantization.crop_low_data->count_ != 1 && !post_op_.quantization.crop_low_data->has_default_values()) + h->add(reg_d_weights_, ch_off); + if (post_op_.quantization.crop_high_data->count_ != 1 && !post_op_.quantization.crop_high_data->has_default_values()) + h->add(reg_d_bias_, ch_off); +} + +template +void jit_uni_quantization_injector_f32::compute_crop(int start_idx, int end_idx, int offset, bool is_scalar, bool is_broadcast) { + if (is_scalar) { + if (post_op_.quantization.crop_low_data->count_ == 1) + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_]); + else if (post_op_.quantization.crop_low_data->has_default_values()) + h->uni_vpxor(vmm_d_weights_, vmm_d_weights_, vmm_d_weights_); + else + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + } else { + if (post_op_.quantization.crop_low_data->count_ == 1) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_]); + else if (post_op_.quantization.crop_low_data->has_default_values()) + h->uni_vpxor(vmm_d_weights_, vmm_d_weights_, vmm_d_weights_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + else + h->uni_vmovups(vmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + } + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) { + for (int jj = start_idx; jj < end_idx; jj++) { + Vmm vmm_dst = Vmm(jj); + h->uni_vmaxps(vmm_dst, vmm_dst, vmm_d_weights_); + } + } + + if (is_scalar) { + if (post_op_.quantization.crop_high_data->count_ == 1) + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_]); + else if (post_op_.quantization.crop_high_data->has_default_values()) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + } else { + if (post_op_.quantization.crop_high_data->count_ == 1) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_]); + else if (post_op_.quantization.crop_high_data->has_default_values()) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + else + h->uni_vmovups(vmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + } + + for (int jj = start_idx; jj < end_idx; jj++) { + Vmm vmm_dst = Vmm(jj); + + if (vmm_d_weights_.getIdx() != vmm_d_bias_.getIdx()) + h->uni_vmaxps(vmm_dst, vmm_dst, vmm_d_weights_); + + h->uni_vminps(vmm_dst, vmm_dst, vmm_d_bias_); + } +} + +template +void jit_uni_quantization_injector_f32::init_input_scale_shift_ptrs(const Xbyak::Operand& ch_off) { + h->mov(reg_d_weights_, reinterpret_cast(post_op_.quantization.input_scale_data->scales_)); + h->mov(reg_d_bias_, reinterpret_cast(post_op_.quantization.input_shift_data->shifts_)); + + if (post_op_.quantization.input_scale_data->count_ != 1) + h->add(reg_d_weights_, ch_off); + if (post_op_.quantization.input_shift_data->count_ != 1 && !post_op_.quantization.input_shift_data->has_default_values()) + h->add(reg_d_bias_, ch_off); +} + +template +void jit_uni_quantization_injector_f32::compute_input_scale_shift(int start_idx, int end_idx, int offset, bool do_rounding, bool is_scalar, bool is_broadcast) { + if (is_scalar) { + if (post_op_.quantization.input_scale_data->count_ == 1) + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_]); + else + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + } else { + if (post_op_.quantization.input_scale_data->count_ == 1) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_]); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + else + h->uni_vmovups(vmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + } + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) { + for (int jj = start_idx; jj < end_idx; jj++) { + Vmm vmm_dst = Vmm(jj); + + h->uni_vmulps(vmm_dst, vmm_dst, vmm_d_weights_); + } + } + + if (is_scalar) { + if (post_op_.quantization.input_shift_data->count_ == 1) + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_]); + else if (post_op_.quantization.input_shift_data->has_default_values()) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + } else { + if (post_op_.quantization.input_shift_data->count_ == 1) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_]); + else if (post_op_.quantization.input_shift_data->has_default_values()) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + else + h->uni_vmovups(vmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + } + + for (int jj = start_idx; jj < end_idx; jj++) { + Vmm vmm_dst = Vmm(jj); + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) + h->uni_vaddps(vmm_dst, vmm_dst, vmm_d_bias_); + else + h->uni_vfmadd213ps(vmm_dst, vmm_d_weights_, vmm_d_bias_); + + if (do_rounding) + h->uni_vroundps(vmm_dst, vmm_dst, 0); + } +} + +template +void jit_uni_quantization_injector_f32::init_output_scale_shift_ptrs(const Xbyak::Operand& ch_off) { + if (!do_dequantization) + return; + + h->mov(reg_d_weights_, reinterpret_cast(post_op_.quantization.output_scale_data->scales_)); + h->mov(reg_d_bias_, reinterpret_cast(post_op_.quantization.output_shift_data->shifts_)); + + if (post_op_.quantization.output_scale_data->count_ != 1) + h->add(reg_d_weights_, ch_off); + if (post_op_.quantization.output_shift_data->count_ != 1 && !post_op_.quantization.output_shift_data->has_default_values()) + h->add(reg_d_bias_, ch_off); +} + +template +void jit_uni_quantization_injector_f32::compute_output_scale_shift(int start_idx, int end_idx, int offset, bool is_scalar, bool is_broadcast) { + if (!do_dequantization) + return; + + if (is_scalar) { + if (post_op_.quantization.output_scale_data->count_ == 1) + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_]); + else + h->movss(xmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + } else { + if (post_op_.quantization.output_scale_data->count_ == 1) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_]); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + else + h->uni_vmovups(vmm_d_weights_, h->ptr[reg_d_weights_ + offset]); + } + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) { + for (int jj = start_idx; jj < end_idx; jj++) { + Vmm vmm_dst = Vmm(jj); + + h->uni_vmulps(vmm_dst, vmm_dst, vmm_d_weights_); + } + } + + if (is_scalar) { + if (post_op_.quantization.output_shift_data->count_ == 1) + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_]); + else if (post_op_.quantization.output_shift_data->has_default_values()) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else + h->movss(xmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + } else { + if (post_op_.quantization.output_shift_data->count_ == 1) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_]); + else if (post_op_.quantization.output_shift_data->has_default_values()) + h->uni_vpxor(vmm_d_bias_, vmm_d_bias_, vmm_d_bias_); + else if (is_broadcast) + h->uni_vbroadcastss(vmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + else + h->uni_vmovups(vmm_d_bias_, h->ptr[reg_d_bias_ + offset]); + } + + for (int jj = start_idx; jj < end_idx; jj++) { + Vmm vmm_dst = Vmm(jj); + + if (vmm_d_weights_.getIdx() == vmm_d_bias_.getIdx()) + h->uni_vaddps(vmm_dst, vmm_dst, vmm_d_bias_); + else + h->uni_vfmadd213ps(vmm_dst, vmm_d_weights_, vmm_d_bias_); + } +} + +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; +template struct jit_uni_quantization_injector_f32; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/x64/injectors/jit_uni_quantization_injector.hpp b/src/cpu/x64/injectors/jit_uni_quantization_injector.hpp new file mode 100644 index 00000000000..cf3546804c1 --- /dev/null +++ b/src/cpu/x64/injectors/jit_uni_quantization_injector.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2019-2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_X64_JIT_UNI_QUANTIZATION_INJECTOR_HPP +#define CPU_X64_JIT_UNI_QUANTIZATION_INJECTOR_HPP + +#include + +#include "common/c_types_map.hpp" +#include "common/primitive_attr.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/x64/jit_generator.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace x64 { + +namespace quantization_injector { + +struct static_params_t { + static_params_t(int vmm_d_weights_idx = 0, int vmm_d_bias_idx = 0, + Xbyak::Reg64 reg_d_weights = Xbyak::Reg64(0), Xbyak::Reg64 reg_d_bias = Xbyak::Reg64(0)) : + vmm_d_weights_idx(vmm_d_weights_idx), vmm_d_bias_idx(vmm_d_bias_idx), reg_d_weights(reg_d_weights), reg_d_bias(reg_d_bias) {} + + int vmm_d_weights_idx; + int vmm_d_bias_idx; + Xbyak::Reg64 reg_d_weights; + Xbyak::Reg64 reg_d_bias; +}; + +struct dynamic_params_t { + dynamic_params_t(Xbyak::Reg64 reg_oc_off = Xbyak::Reg64(0), const std::map& vmm_idx_off = {}, data_type_t dst_dt = dnnl_f32) : + reg_oc_off(reg_oc_off), reg_oc_off_addr(0), vmm_idx_off(vmm_idx_off), dst_dt(dnnl_f32), useAddr(false) { + + } + + dynamic_params_t(Xbyak::Address reg_oc_off, const std::map& vmm_idx_off, data_type_t dst_dt = dnnl_f32) : + reg_oc_off(0), reg_oc_off_addr(reg_oc_off), vmm_idx_off(vmm_idx_off), dst_dt(dnnl_f32), useAddr(true) { + } + + Xbyak::Reg64 reg_oc_off; + Xbyak::Address reg_oc_off_addr; + std::map vmm_idx_off; + data_type_t dst_dt; + bool useAddr; +}; + +} // quantization_injector + +template ::Vmm> +struct jit_uni_quantization_injector_f32 { + jit_uni_quantization_injector_f32(jit_generator* host, dnnl_post_ops::entry_t post_op, + Vmm vmm_d_weights, Vmm vmm_d_bias, Xbyak::Reg64 reg_d_weights, Xbyak::Reg64 reg_d_bias) + : h(host), post_op_(post_op), vmm_d_weights_(vmm_d_weights), vmm_d_bias_(vmm_d_bias), reg_d_weights_(reg_d_weights), reg_d_bias_(reg_d_bias) { + assert(post_op.is_quantization()); + assert(utils::one_of(post_op.quantization.alg, alg_kind::quantization_quantize, alg_kind::quantization_quantize_dequantize)); + + do_dequantization = post_op_.quantization.alg == alg_kind::quantization_quantize_dequantize; + + xmm_d_weights_ = Xbyak::Xmm(vmm_d_weights.getIdx()); + xmm_d_bias_ = Xbyak::Xmm(vmm_d_bias.getIdx()); + } + + void init_crop_ptrs(const Xbyak::Operand& ch_off); + void init_input_scale_shift_ptrs(const Xbyak::Operand& ch_off); + void init_output_scale_shift_ptrs(const Xbyak::Operand& ch_off); + + void compute_crop(int start_idx, int end_idx, int offset, bool is_scalar = false, bool is_broadcast = false); + void compute_input_scale_shift(int start_idx, int end_idx, int offset, bool do_rounding, bool is_scalar = false, bool is_broadcast = false); + void compute_output_scale_shift(int start_idx, int end_idx, int offset, bool is_scalar = false, bool is_broadcast = false); + +private: + jit_generator* h; + + size_t vlen = cpu_isa_traits::vlen; + + dnnl_post_ops::entry_t post_op_; + + Vmm vmm_d_weights_; + Vmm vmm_d_bias_; + Xbyak::Xmm xmm_d_weights_; + Xbyak::Xmm xmm_d_bias_; + + Xbyak::Reg64 reg_d_weights_; + Xbyak::Reg64 reg_d_bias_; + + bool do_dequantization; +}; + +} // namespace x64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp index ae8ce032b72..5377713f9f8 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp @@ -49,7 +49,7 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), avx2), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -63,10 +63,12 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32( memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -147,13 +149,22 @@ void iterate(const int load_loop_blk, const int ur, const F &f) { void jit_avx2_1x1_conv_kernel_f32::apply_postops( const int load_loop_blk, const int ur, const int load_dim_tail) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { assert(ur * load_loop_blk < 14); Label store_nopost_ops; test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); jz(store_nopost_ops, T_NEAR); + std::map vmm_idx_off; + iterate(load_loop_blk, ur, load_dim_tail, + [&](const bool, const int i, const int j) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i, j), i * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -203,14 +214,14 @@ void jit_avx2_1x1_conv_kernel_f32::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(load_loop_blk, ur, load_dim_tail, [&](const bool, const int i, const int j) { vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i, j)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } L(store_nopost_ops); } @@ -594,6 +605,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate() { mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); if (jcp.prop_kind == backward_weights) mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto generate_load_loop_body = [&](int load_loop_blk) { generate_bcast_loop(load_loop_blk); @@ -627,6 +639,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate() { default: assert(!"invalid prop_kind"); } sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; Label load_loop_blk_8; @@ -752,6 +765,9 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, const int prelu_ind = post_ops.find(primitive_kind::prelu, 0, dw_conv_ind); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; + if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), @@ -784,7 +800,7 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, jcp.ic = rnd_up(jcp.ic, simd_w); } - if (jcp.with_eltwise || jcp.with_binary) + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) VDISPATCH_CONV_IC(jcp.isa >= avx2, VERBOSE_UNSUPPORTED_FEATURE, "eltwise and binary post-ops not implemented on isa"); @@ -793,7 +809,7 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(jcp.isa, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp index b7d429e00e4..cbf2b98ecc1 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp @@ -85,6 +85,12 @@ struct jit_avx2_1x1_conv_kernel_f32 : public jit_generator { constexpr static int reg_dw_binary_output_off = 3 * reg64_size_; constexpr static int stack_space_needed = 4 * reg64_size_; + reg64_t reg_oc_off = load_loop_iter; + reg64_t reg_d_weights = aux_reg_bcast_data; + reg64_t reg_d_bias = reduce_loop_iter; // todo: [AV] check, conflict with out_off_oprnd (r15) + ymm_t ymm_d_weights = Xbyak::Ymm(14); + ymm_t ymm_d_bias = Xbyak::Ymm(15); + ymm_t vreg_bcast = ymm_t(15); ymm_t vtmp = ymm_t(14); diff --git a/src/cpu/x64/jit_avx2_1x1_convolution.cpp b/src/cpu/x64/jit_avx2_1x1_convolution.cpp index 9aa8a335698..303dd5d1fdb 100644 --- a/src/cpu/x64/jit_avx2_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx2_1x1_convolution.cpp @@ -217,6 +217,8 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = static_cast(p.output_data) - dst_off; + p.oc_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&p); }; @@ -296,6 +298,8 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ch * jcp_dw->ch_block * sizeof(float); + (*dw_jit_ker)(&par_conv_dw); for (int i = 0; i < jcp_dw->kh; ++i) diff --git a/src/cpu/x64/jit_avx2_1x1_convolution.hpp b/src/cpu/x64/jit_avx2_1x1_convolution.hpp index 4d942d1a50a..f4d30ea47e0 100644 --- a/src/cpu/x64/jit_avx2_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx2_1x1_convolution.hpp @@ -72,6 +72,9 @@ struct jit_avx2_1x1_convolution_fwd_t : public primitive_t { VDISPATCH_CONV( attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); const convolution_desc_t *conv_d = desc(); const memory_desc_t *src_d = src_md(); @@ -326,12 +329,12 @@ struct jit_avx2_1x1_convolution_fwd_t : public primitive_t { if (isa == avx2) { CHECK(safe_ptr_assign(kernel_dw_avx2, new dw_conv_kernel_t( - *(pd()->jcp_dw_), *pd()->dst_md(0)))); + *(pd()->jcp_dw_), *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_avx2->create_kernel()); } else { CHECK(safe_ptr_assign(kernel_dw_sse41, new dw_conv_kernel_t( - *(pd()->jcp_dw_), *pd()->dst_md(0)))); + *(pd()->jcp_dw_), *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_sse41->create_kernel()); } } diff --git a/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp index eda08c5e741..fc2b3306cbb 100644 --- a/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp @@ -54,7 +54,7 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), avx2), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -68,10 +68,12 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32( memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -208,11 +210,19 @@ void iterate(const int load_loop_blk, const int ur, const F &f) { void jit_avx2_conv_fwd_kernel_f32::apply_postops( const int oc_blocks, const int ur_w, const int oc_tail) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { Label regular_store; test(reg_ci_flag, FLAG_IC_LAST); je(regular_store, T_NEAR); + std::map vmm_idx_off; + iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { + vmm_idx_off.insert({get_ymm_idx(ur_w, i, j), i * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {ymm_d_weights.getIdx(), ymm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -244,14 +254,14 @@ void jit_avx2_conv_fwd_kernel_f32::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { vmm_idxs.emplace(get_ymm_idx(ur_w, i, j)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } L(regular_store); } @@ -676,6 +686,8 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary); const int prelu_ind = post_ops.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -696,7 +708,7 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, if (mimo) jcp.ic = rnd_up(jcp.ic, simd_w); } - if (jcp.with_eltwise || jcp.with_binary) + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) VDISPATCH_CONV_IC(mayiuse(avx2), VERBOSE_UNSUPPORTED_FEATURE, "eltwise and binary post-ops not implemented on isa"); @@ -705,7 +717,7 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(jcp.isa, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp b/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp index 9050c113a2f..0fb3c7243f9 100644 --- a/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_avx2_conv_kernel_f32.hpp @@ -80,6 +80,12 @@ struct jit_avx2_conv_fwd_kernel_f32 : public jit_generator { Xbyak::Ymm ytmp = Xbyak::Ymm(14); + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = ki_iter; + + Xbyak::Ymm ymm_d_weights = Xbyak::Ymm(14); + Xbyak::Ymm ymm_d_bias = Xbyak::Ymm(15); + inline void oh_step_unroll_kw( int ur_w, int pad_l, int pad_r, int oc_blocks); inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); diff --git a/src/cpu/x64/jit_avx2_convolution.cpp b/src/cpu/x64/jit_avx2_convolution.cpp index c9daceacd62..bde9a00d2bd 100644 --- a/src/cpu/x64/jit_avx2_convolution.cpp +++ b/src/cpu/x64/jit_avx2_convolution.cpp @@ -138,7 +138,7 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.flags |= FLAG_IC_FIRST; } - if ((jcp.with_eltwise || jcp.with_binary) + if ((jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) && icb + 1 == jcp.nb_ic) par_conv.flags |= FLAG_IC_LAST; @@ -164,6 +164,7 @@ void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); par_conv.dst_orig = dst; + par_conv.oc_off = _oc * oc_bias_scale * sizeof(float); (*kernel_)(&par_conv); } diff --git a/src/cpu/x64/jit_avx2_convolution.hpp b/src/cpu/x64/jit_avx2_convolution.hpp index 1bf683eb7f5..b848264231f 100644 --- a/src/cpu/x64/jit_avx2_convolution.hpp +++ b/src/cpu/x64/jit_avx2_convolution.hpp @@ -59,6 +59,9 @@ struct jit_avx2_convolution_fwd_t : public primitive_t { VDISPATCH_CONV( attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); CHECK(jit_avx2_conv_fwd_kernel_f32::init_conf( jcp_, *desc(), src_md(), weights_md(), dst_md(), *attr())); diff --git a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp index 9ebd31809ef..9b756107e54 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp @@ -51,7 +51,7 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name()), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -66,10 +66,12 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -169,6 +171,16 @@ static void iterate(const int load_loop_blk, const int ur, const F &fun) { void jit_avx512_common_1x1_conv_kernel::apply_postops( const bool is_out_layout_nxc, const int load_loop_blk, const int ur) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -194,7 +206,7 @@ void jit_avx512_common_1x1_conv_kernel::apply_postops( mov(abi_param1, ptr[rsp + reg_abi_param1_backup]); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); if (jcp.with_dw_conv) { sub(aux_reg_output_data, EVEX_compress_addr(rsp, reg_dw_binary_output_off)); @@ -205,7 +217,7 @@ void jit_avx512_common_1x1_conv_kernel::apply_postops( vmm_idxs.emplace( vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -310,7 +322,7 @@ void jit_avx512_common_1x1_conv_kernel::reduce_loop( L(store_noadd); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { Label store_nopostops; test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); jz(store_nopostops, T_NEAR); @@ -433,6 +445,7 @@ void jit_avx512_common_1x1_conv_kernel::generate() { mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); if (jcp.prop_kind == backward_weights) mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); const int load_dim_tail = (one_of(jcp.prop_kind, forward_training, forward_inference) @@ -457,6 +470,7 @@ void jit_avx512_common_1x1_conv_kernel::generate() { } bcast_loop(load_loop_blk); add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); const size_t offst_with_dw_conv = load_loop_blk * jcp.load_block * jcp.typesize_out * (is_out_layout_nxc(jcp) @@ -627,6 +641,8 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); const int prelu_ind = post_ops.find(primitive_kind::prelu, 0, dw_conv_ind); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them @@ -657,7 +673,7 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp index 9a926841fc0..e71f8e138fd 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp @@ -67,7 +67,7 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator { reg64_t reg_load_loop_work = rsi; reg64_t reg_reduce_loop_work = r11; reg64_t reg_bcast_loop_iter = rdx; - reg64_t reduce_loop_iter = abi_param1; + reg64_t reduce_loop_iter = r13; reg64_t reg_reduce_pos_flag = rax; reg64_t reg_output_stride = r13; reg64_t reg_bias_data = r12; @@ -87,6 +87,13 @@ struct jit_avx512_common_1x1_conv_kernel : public jit_generator { constexpr static int reg_dw_binary_output_off = 3 * reg64_size_; constexpr static int stack_space_needed = 4 * reg64_size_; + reg64_t reg_oc_off = abi_param1; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = r13; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + void bcast_loop(int load_loop_blk); void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); diff --git a/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp index c803cd539cb..8da51fb9f4b 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_common_1x1_convolution.cpp @@ -220,6 +220,7 @@ void jit_avx512_common_1x1_convolution_fwd_t(p.output_data) - dst_off; + p.oc_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); (*kernel_)(&p); }; @@ -368,6 +369,8 @@ void jit_avx512_common_1x1_convolution_fwd_tattr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); const convolution_desc_t *conv_d = desc(); const memory_desc_t *src_d = src_md(); @@ -295,7 +298,7 @@ struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t { if (pd()->jcp_.with_dw_conv) { CHECK(safe_ptr_assign(kernel_dw_, new dw_conv_kernel_t( - pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0)))); + pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_->create_kernel()); } diff --git a/src/cpu/x64/jit_avx512_common_conv_kernel.cpp b/src/cpu/x64/jit_avx512_common_conv_kernel.cpp index 21f353731cf..feaa2fe90e0 100644 --- a/src/cpu/x64/jit_avx512_common_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_common_conv_kernel.cpp @@ -96,7 +96,7 @@ _jit_avx512_common_conv_fwd_kernel::_jit_avx512_common_conv_fwd_kernel( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name()), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -111,10 +111,12 @@ _jit_avx512_common_conv_fwd_kernel::_jit_avx512_common_conv_fwd_kernel( use_exact_tail_scalar_bcast}; const binary_injector::static_params_t static_params { this->param1, rhs_args_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -144,6 +146,15 @@ static void iterate(const int nb_oc_blocking, const int ur_w, const F &fun) { template void _jit_avx512_common_conv_fwd_kernel::apply_postops(int ur_w) { + std::map vmm_idx_off; + iterate(jcp.nb_oc_blocking, ur_w, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vmm_out_idx(i_ur, i_load), i_load * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -164,13 +175,13 @@ void _jit_avx512_common_conv_fwd_kernel::apply_postops(int ur_w) { } }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(jcp.nb_oc_blocking, ur_w, [&](const bool, const int i_load, const int i_ur) { vmm_idxs.emplace(vmm_out_idx(i_ur, i_load)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -223,7 +234,7 @@ void _jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) { L(post_ops_label); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { test(reg_channel, FLAG_IC_LAST); jz(store_label, T_NEAR); @@ -956,6 +967,8 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary); const int prelu_ind = post_ops.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -964,7 +977,7 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_common_conv_kernel.hpp b/src/cpu/x64/jit_avx512_common_conv_kernel.hpp index 163ff4ddc78..e8f1ec896d0 100644 --- a/src/cpu/x64/jit_avx512_common_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_common_conv_kernel.hpp @@ -113,6 +113,12 @@ struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { std::unique_ptr> postops_injector_; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = reg_kj; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + inline void prepare_output(int ur_w); inline void apply_postops(int ur_w); inline void store_output(int ur_w); diff --git a/src/cpu/x64/jit_avx512_common_convolution.cpp b/src/cpu/x64/jit_avx512_common_convolution.cpp index e95dadba586..31027a2221f 100644 --- a/src/cpu/x64/jit_avx512_common_convolution.cpp +++ b/src/cpu/x64/jit_avx512_common_convolution.cpp @@ -36,7 +36,7 @@ using jit_conv_ker_t = void (*)(jit_conv_call_s *); inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, - int channel, int kh_padding, int reduce_work, int load_work) { + int channel, int kh_padding, int reduce_work, int load_work, int oc_off) { p.src = src; p.dst = dst; p.filt = filt; @@ -47,6 +47,7 @@ inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, p.kh_padding = kh_padding; p.reduce_work = reduce_work; p.load_work = load_work; + p.oc_off = oc_off; ker(&p); } @@ -54,17 +55,17 @@ inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, inline void jit_conv_ker_pipeline_iw_thr(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int iwb, int reduce_work, - int load_work) { + int load_work, int oc_off) { p.iwb = iwb; jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - reduce_work, load_work); + reduce_work, load_work, oc_off); } inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int kd_padding, - int reduce_work, int load_work) { + int reduce_work, int load_work, int oc_off) { p.src = src; p.dst = dst; p.filt = filt; @@ -76,15 +77,17 @@ inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker, p.kd_padding = kd_padding; p.reduce_work = reduce_work; p.load_work = load_work; + p.oc_off = oc_off; ker(&p); } + // The special case for the driver with ow-parallelization (FWD) inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int owb, int reduce_work, int load_work, const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, - int flags) { + int flags, int oc_off) { p.owb = owb; p.flags = flags; @@ -92,7 +95,7 @@ inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - reduce_work, load_work); + reduce_work, load_work, oc_off); } // The special case for the driver with ow-parallelization (FWD) @@ -101,7 +104,7 @@ inline void jit_conv_3d_ker_pipeline_ow_thr(const jit_conv_ker_t ker, jit_conv_call_s &p, const void *src, const void *dst, const void *filt, const void *bias, int channel, int kh_padding, int kd_padding, int owb, int reduce_work, int load_work, const void *post_ops_binary_rhs_arg_vec, - const void *dst_orig, int flags) { + const void *dst_orig, int flags, int oc_off) { p.dst_orig = dst_orig; p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; @@ -110,7 +113,7 @@ inline void jit_conv_3d_ker_pipeline_ow_thr(const jit_conv_ker_t ker, p.flags = flags; jit_conv_3d_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - kd_padding, reduce_work, load_work); + kd_padding, reduce_work, load_work, oc_off); } inline void jit_conv_ker_pipeline_bwd_w(const jit_conv_ker_t ker, @@ -118,7 +121,7 @@ inline void jit_conv_ker_pipeline_bwd_w(const jit_conv_ker_t ker, const void *bias, int channel, int kh_padding, size_t reduce_work, size_t load_work) { jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, - reduce_work, load_work); + reduce_work, load_work, 0); } void jit_conv_2d_ker_bwd_w_pipeline(const jit_conv_ker_t ker, @@ -278,9 +281,10 @@ void jit_avx512_common_convolution_fwd_thas_default_values( primitive_attr_t::skip_mask_t::post_ops, dst_type), VERBOSE_UNSUPPORTED_ATTR); + VDISPATCH_CONV(!this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); CHECK(jit_avx512_common_conv_fwd_kernel::init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp index e57e4b9ed76..ca96d28b8ae 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp @@ -41,7 +41,7 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), avx512_core_amx), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; const auto &rhs_addr_reg = bin_injector_helper_reg_1; const auto &rhs_helper_reg = bin_injector_helper_reg_2; @@ -58,10 +58,12 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params = + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -226,22 +228,27 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_sum(const Zmm zmm_out, void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_postops(const Zmm zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, - const Xbyak::Address &addr, const size_t off, const bool mask_flag) { + const Xbyak::Address &addr, const size_t off, const bool mask_flag, const int ocb) { if (jcp.with_eltwise || jcp.with_binary - || (jcp.with_sum && p_sum_scale != nullptr)) { + || (jcp.with_sum && p_sum_scale != nullptr) || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + vmm_idx_off.insert({zmm_out.getIdx(), ocb * jcp.oc_block * sizeof(float)}); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt}; + + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; + apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag); const auto vmm_idx = zmm_out.getIdx(); if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, out_ptr); rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off); if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - - postops_injector_->compute_vector(vmm_idx, rhs_arg_params); - } else { - postops_injector_->compute_vector(vmm_idx); } + + postops_injector_->compute_vector_range({(size_t)vmm_idx}, rhs_arg_params, ddp, qdp); } } @@ -451,7 +458,7 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_int8( if (jcp.with_bias) vaddps(zmm_out_msk, zmm_out, zmm_bias); - apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag); + apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag, ocb); if (jcp.dst_scale) { mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]); @@ -564,7 +571,7 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_bf16( static constexpr auto skip_sum_in_injection = nullptr; apply_postops(zmm_out, skip_sum_in_injection, skip_sum_in_injection, addr, - off, mask_flag); + off, mask_flag, ocb); if (jcp.dst_dt == data_type::bf16) { store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag); @@ -1131,6 +1138,12 @@ status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); jcp.sum_dt = p.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = p.entry_[sum_ind].sum.dt; + + jcp.with_depthwise = p.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = p.find(primitive_kind::quantization) != -1; + jcp.post_ops = p; jcp.is_fast_postops = is_fast_postops(jcp); @@ -1139,7 +1152,7 @@ status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, const bool sum_requires_scale_one = sum_at_pos_0_only; const bool sum_requires_zp_zero = sum_at_pos_0_only; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp index 1fa1ba63d2a..d78d4ef440e 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp @@ -114,6 +114,12 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { const Xbyak::Opmask ktail_mask = k2; + const Xbyak::Reg64 reg_d_weights = reg_last_h; + const Xbyak::Reg64 reg_d_bias = reg_oc_blocks; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + bool is_bf16() const; void init_runtime_counters(); @@ -135,7 +141,7 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { Xbyak::Zmm zmm_out(const int idx) { const int upper_limit = is_bf16() ? zmm_idx_limit_bf16 : zmm_idx_limit_int8; - assert(upper_limit > idx); +// assert(upper_limit > idx); MAYBE_UNUSED(upper_limit); return Xbyak::Zmm(idx); } @@ -151,7 +157,7 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { const bool mask_flag); void apply_postops(const Xbyak::Zmm zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, const Xbyak::Address &addr, - const size_t off, const bool mask_flag); + const size_t off, const bool mask_flag, const int ocb); static bool is_fast_postops(const jit_conv_conf_t &jcp); void store_output_vectors_int8(int ocb, int osb); void store_output_vector_int8( diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp index 5a61a9fa38f..af78ae8d280 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp @@ -168,7 +168,7 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( = jcp.src_zero_point ? zp_compensation + oc : nullptr; p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr; p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr; - + p.oc_off = oc * sizeof(float); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp index 9479230dc13..e6aae72d11d 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp @@ -53,7 +53,7 @@ struct jit_avx512_core_amx_1x1_convolution_fwd_t : public primitive_t { && utils::one_of(dst_md(0)->data_type, f32, bf16)) && IMPLICATION(with_bias(), utils::one_of(weights_md(1)->data_type, f32, bf16)) - && attr()->has_default_values(smask_t::post_ops); + && attr()->has_default_values(smask_t::post_ops, dst_md(0)->data_type); bool is_int8_convolution = utils::one_of(src_md(0)->data_type, s8, u8) && weights_md(0)->data_type == s8 diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index d7141460dd9..50c2ab1bca0 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -1064,7 +1064,7 @@ jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), avx512_core_amx), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; const auto &rhs_addr_reg = bin_injector_helper_reg_1; const auto &rhs_helper_reg = bin_injector_helper_reg_2; @@ -1083,9 +1083,12 @@ jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t( const binary_injector::static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params = + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } copy_to_pbuffer_ = utils::make_unique(jcp); @@ -1395,22 +1398,27 @@ void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out, void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, - const Xbyak::Address &addr, const size_t off, const bool mask_flag) { + const Xbyak::Address &addr, const size_t off, const bool mask_flag, const int ocb) { if (jcp.with_eltwise || jcp.with_binary - || (jcp.with_sum && p_sum_scale != nullptr)) { + || (jcp.with_sum && p_sum_scale != nullptr) || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + vmm_idx_off.insert({zmm_out.getIdx(), ocb * jcp.oc_block * sizeof(float)}); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off, jcp.dst_dt}; + + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; + apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag); const auto vmm_idx = zmm_out.getIdx(); if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out_ptr); rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off); if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); - - postops_injector_->compute_vector(vmm_idx, rhs_arg_params); - } else { - postops_injector_->compute_vector(vmm_idx); } + + postops_injector_->compute_vector_range({(size_t)vmm_idx}, rhs_arg_params, ddp, qdp); } } @@ -1455,7 +1463,7 @@ void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16( static constexpr auto skip_sum_injection = nullptr; apply_postops(zmm_out, skip_sum_injection, skip_sum_injection, addr, off, - mask_flag); + mask_flag, ocb); if (jcp.dst_dt == data_type::bf16) { store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag); @@ -1524,7 +1532,7 @@ void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8( EVEX_compress_addr(reg_ptr_scales, scale_offset)); if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias); - apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag); + apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag, ocb); if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); } if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); } @@ -2499,6 +2507,8 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, const int prelu_ind = p.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); jcp.sum_dt = p.get_sum_dt(jcp.dst_dt); + jcp.with_depthwise = p.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = p.find(primitive_kind::quantization) != -1; jcp.post_ops = p; @@ -2507,7 +2517,7 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, const bool sum_requires_scale_one = sum_at_pos_0_only; const bool sum_requires_zp_zero = sum_at_pos_0_only; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); @@ -2557,9 +2567,10 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, jcp.nb_oc_blocking_thr_chunk = 1; - const int target_palette = amx::get_target_palette(); - jcp.max_tiles = amx::get_max_tiles(target_palette); - jcp.full_tile_width = amx::get_max_rows(target_palette); + // @todo This change must be explained with a comment + // const int target_palette = amx::get_target_palette(); + jcp.max_tiles = 8; //amx::get_max_tiles(target_palette); + jcp.full_tile_width = 16; //amx::get_max_rows(target_palette); VDISPATCH_CONV_IC(!(jcp.max_tiles != 8 || jcp.full_tile_width != 16), VERBOSE_BLOCKING_FAIL, "bad blocking parameters"); diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp index 4aa52520753..3ba6367f0c9 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp @@ -342,6 +342,12 @@ struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator { const Xbyak::Reg64 bin_injector_helper_reg_2 = r15; const Xbyak::Reg64 bin_injector_helper_reg_3 = r11; + const Xbyak::Reg64 reg_d_weights = reg_zp_compensation; + const Xbyak::Reg64 reg_d_bias = reg_src_zero_point; + + const Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + const Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + // AUX: Steps, shifts and offsets size_t get_inp_icb_step() const; size_t get_wei_icb_step() const; @@ -390,7 +396,7 @@ struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator { const bool mask_flag); void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, const int32_t *p_sum_zp, const Xbyak::Address &addr, - const size_t off, const bool mask_flag); + const size_t off, const bool mask_flag, const int ocb); inline void store_output_ymm_bf16( const int idx, const Xbyak::Address &addr, const bool mask_flag); void store_output_vector_bf16( diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp index 2377ca8bf47..336f6a510c5 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp @@ -393,6 +393,7 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( p.oc_blocks = occ * jcp.nb_oc_blocking; + p.oc_off = oc * sizeof(float); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; @@ -787,6 +788,7 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( p.oc_blocks = occ * jcp.nb_oc_blocking; + p.oc_off = oc * sizeof(float); p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp index 6be4ec34f25..85e64d2f2aa 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp @@ -56,7 +56,7 @@ struct jit_avx512_core_amx_convolution_fwd_t : public primitive_t { && utils::one_of(dst_md(0)->data_type, f32, bf16)) && IMPLICATION(with_bias(), utils::one_of(weights_md(1)->data_type, f32, bf16)) - && attr()->has_default_values(smask_t::post_ops); + && attr()->has_default_values(smask_t::post_ops, dst_md(0)->data_type); bool is_int8_convolution = utils::one_of(src_md(0)->data_type, s8, u8) && weights_md(0)->data_type == s8 diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp index 6f64900c0a6..f9c596a181d 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp @@ -46,7 +46,7 @@ jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), avx512_core_bf16), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -61,10 +61,12 @@ jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa)) @@ -181,7 +183,17 @@ static void iterate(const int load_loop_blk, const int ur, const F &f) { void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops( const int load_loop_blk, const int ur) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.oc_block * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -237,7 +249,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { @@ -246,7 +258,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops( vmm_idxs.emplace( vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -854,6 +866,8 @@ void jit_avx512_core_bf16_1x1_conv_kernel::reduce_loop( mov(aux_reg_bcast_data, aux1_reg_bcast_data); init(); + push(reg_oc_off); + mov(reduce_loop_iter, reg_reduce_loop_work); Label reduce_loop_exit; cmp(reduce_loop_iter, jcp.reduce_loop_unroll); @@ -875,6 +889,9 @@ void jit_avx512_core_bf16_1x1_conv_kernel::reduce_loop( fma_block(true); L(reduce_loop_exit); + + pop(reg_oc_off); + store(); } @@ -1051,6 +1068,7 @@ void jit_avx512_core_bf16_1x1_conv_kernel::generate() { mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); } + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto load_loop_body = [&](int load_loop_blk) { Label no_update_mask, update_mask_done; if (load_dim_tail) { @@ -1072,6 +1090,8 @@ void jit_avx512_core_bf16_1x1_conv_kernel::generate() { mov(reg_load_loop_work, ptr[rsp + reg_load_loop_work_off]); add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); + const size_t off_with_dw_conv = load_loop_blk * jcp.load_block * jcp.typesize_out * (is_out_layout_nxc() @@ -1260,6 +1280,9 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf( const int prelu_ind = post_ops.find(primitive_kind::prelu, 0, dw_conv_ind); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; + if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), @@ -1273,7 +1296,7 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf( static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp index b7b457f8cc9..581a44b4e12 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp @@ -108,6 +108,12 @@ struct jit_avx512_core_bf16_1x1_conv_kernel : public jit_generator { Xbyak::Opmask half_mask = Xbyak::Opmask(6); Xbyak::Opmask half_mask_hi = Xbyak::Opmask(5); Xbyak::Label dst_prm_table; + reg64_t reg_oc_off = abi_param1; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = aux_reg_bcast_data; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); constexpr static int reg64_size_ = sizeof(int64_t); constexpr static int bcast_loop_work_offt = 0; diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp index fba53b67add..2b5548842b3 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp @@ -274,6 +274,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = static_cast(p.output_data) - dst_off * dst_d.data_type_size(); + p.oc_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); (*kernel_)(&p); }; @@ -382,6 +383,8 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t::execute_forward_thr( = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ch * jcp_dw->ch_block * sizeof(float); + (*kernel_dw_)(&par_conv_dw); for (int i = 0; i < jcp_dw->kh; ++i) diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp index b03ed83e890..67052c021a9 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp @@ -339,7 +339,7 @@ struct jit_avx512_core_bf16_1x1_convolution_fwd_t : public primitive_t { if (pd()->jcp_.with_dw_conv) { CHECK(safe_ptr_assign(kernel_dw_, - new dw_conv_kernel_t(*(pd()->jcp_dw_), *pd()->dst_md(0)))); + new dw_conv_kernel_t(*(pd()->jcp_dw_), *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); CHECK(kernel_dw_->create_kernel()); } diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp index 5d0e49834de..efd749faecd 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp @@ -100,7 +100,7 @@ _jit_avx512_core_bf16_fwd_kernel::_jit_avx512_core_bf16_fwd_kernel( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), avx512_core_bf16), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -118,10 +118,12 @@ _jit_avx512_core_bf16_fwd_kernel::_jit_avx512_core_bf16_fwd_kernel( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa)) bf16_emu_ = utils::make_unique(this, @@ -168,7 +170,16 @@ static void iterate(const int nb_oc_block, const int ur_w, const F &f) { template void _jit_avx512_core_bf16_fwd_kernel::apply_postops(int ur_w) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(jcp.nb_oc_blocking, ur_w, + [&](const bool, const int k, const int j) { + vmm_idx_off.insert({vmm_dst_idx(j, k), k * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -205,7 +216,7 @@ void _jit_avx512_core_bf16_fwd_kernel::apply_postops(int ur_w) { jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { @@ -213,7 +224,7 @@ void _jit_avx512_core_bf16_fwd_kernel::apply_postops(int ur_w) { [&](const bool, const int k, const int j) { vmm_idxs.emplace(vmm_dst_idx(j, k)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -1000,6 +1011,8 @@ status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary); const int prelu_ind = post_ops.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0; if (is_data_layout_nxc) @@ -1017,7 +1030,7 @@ status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp index 3e7cc332423..175520ba3b4 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp @@ -127,10 +127,16 @@ struct _jit_avx512_core_bf16_fwd_kernel : public jit_generator { constexpr static int off_reg_ker_ = 8; constexpr static int stack_space_needed_ = 16; - std::unique_ptr> + std::unique_ptr> postops_injector_; std::unique_ptr bf16_emu_; + reg64_t reg_d_weights = r15; + reg64_t reg_d_bias = reg_kj; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + inline void prepare_dst(int ur_w); void apply_postops(int ur_w); inline void store_dst(int ur_w); diff --git a/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp b/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp index 82f55def024..13127783552 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_convolution.cpp @@ -132,6 +132,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d( par_conv.dst_orig = dst; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + par_conv.oc_off = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&par_conv); if (jcp.loop_order == loop_cwgn) { @@ -255,6 +257,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d( par_conv.dst_orig = dst; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + par_conv.oc_off = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&par_conv); src_w += src_h_stride * jcp.stride_h; @@ -392,6 +396,8 @@ void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d( par_conv.dst_orig = dst; par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + par_conv.oc_off = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); + (*kernel_)(&par_conv); src_w += src_h_stride * jcp.stride_h; diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp index 4cd112b00cf..dca5a0a4c7f 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp @@ -34,9 +34,9 @@ using namespace Xbyak; using namespace dnnl::impl::utils; jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) - : jit_generator(jit_name()), jcp(ajcp) { - if (jcp.with_eltwise || jcp.with_binary) { + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr) + : jit_generator(jit_name()), jcp(ajcp), attr_(attr) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -52,10 +52,12 @@ jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16( use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa)) bf16_emu_ = utils::make_unique(this, @@ -202,7 +204,15 @@ static void iterate(const int ur_ch_blocks, const int ur_w, const F &f) { void jit_avx512_dw_conv_fwd_kernel_bf16::apply_postops( int ur_ch_blocks, int ur_w, bool last_ch_block_flag) { - if (this->jcp.with_eltwise || this->jcp.with_binary) { + if (this->jcp.with_eltwise || this->jcp.with_binary || this->jcp.with_depthwise || this->jcp.with_quantization) { + std::map vmm_idx_off; + iterate(ur_ch_blocks, ur_w, [&](int ch, int ow, int) { + vmm_idx_off.insert({get_acc_reg_idx(ch * ur_w + ow), ch * jcp.ch_block * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { @@ -244,20 +254,20 @@ void jit_avx512_dw_conv_fwd_kernel_bf16::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params); + vmm_idxs, rhs_arg_params, ddp, qdp); } else if (last_ch_block_flag) postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, rhs_arg_params_tail, ddp, qdp); else /* if (!last_ch_block_flag) */ postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params); + vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(ur_ch_blocks, ur_w, [&](int ch, int ow, int) { vmm_idxs.emplace(get_acc_reg_idx(ch * ur_w + ow)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp index 7ec11eeed79..91be4a79ae4 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp @@ -25,6 +25,7 @@ #include "cpu/x64/jit_primitive_conf.hpp" #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" namespace dnnl { namespace impl { @@ -35,9 +36,10 @@ struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_fwd_kernel_bf16) jit_avx512_dw_conv_fwd_kernel_bf16( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md); + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t& attr); jit_conv_conf_t jcp; + const primitive_attr_t& attr_; private: using reg64_t = const Xbyak::Reg64; @@ -70,6 +72,12 @@ struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator { mask_t ktail_mask = k_oc_tail_mask; mask_t k_ch_tail_mask_extended = Xbyak::Opmask(3); + reg64_t reg_d_weights = abi_not_param1; + reg64_t reg_d_bias = iter_kh; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0); Xbyak::Zmm zmm_src_reg = Xbyak::Zmm(1); Xbyak::Zmm zmm_prev_dst = Xbyak::Zmm(31); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp index b8e581de704..35445f2dcdb 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include +#include #include "common/c_types_map.hpp" #include "common/memory.hpp" @@ -52,7 +53,7 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel:: , jcp(ajcp) , attr_(attr) , postops_injector_(nullptr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -70,10 +71,12 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel:: use_exact_tail_scalar_bcast}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa)) bf16_emu_ = utils::make_unique(this, @@ -230,7 +233,15 @@ template void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::apply_postops( const int load_loop_blk, const int ur, const bool mask_flag_in, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const bool, const int i_load, const int i_ur) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; apply_sum(load_loop_blk, ur, mask_flag_in, p_sum_scale, p_sum_zp); @@ -278,7 +289,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::apply_postops( jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { @@ -287,7 +298,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::apply_postops( vmm_idxs.emplace( vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -572,6 +583,8 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop( Label reduce_loop; Label reduce_loop_tail; + push(reg_oc_off); + mov(aux_reg_load_data, reg_load_data); mov(aux_reg_bcast_data, aux1_reg_bcast_data); @@ -597,6 +610,8 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop( fma_block(false); } + pop(reg_oc_off); + if (jcp.oc_without_padding != jcp.oc) { Label end_store, common_store; mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); @@ -626,7 +641,6 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop( template void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() { - preamble(); const int simd_w = jcp.ic_block; @@ -675,6 +689,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() { mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work); mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); if (jcp.ic_block == 4 && jcp.dst_dt == data_type::bf16) { Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32(); @@ -760,6 +775,7 @@ void _jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() { mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out); sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; Label load_loop_blk[7]; @@ -907,6 +923,11 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); jcp.with_sum = sum_ind != -1; + if (jcp.with_sum) + jcp.sum_dt = post_ops.entry_[sum_ind].sum.dt; + + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them @@ -944,7 +965,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( static constexpr bool sum_requires_scale_one = false; static constexpr bool sum_requires_zp_zero = false; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp index 6eb52072d8c..8fb71e6b3ff 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp @@ -83,6 +83,14 @@ struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel : public jit_generator { const Xbyak::Opmask k_load_dim_tail_mask = Xbyak::Opmask(4); const Xbyak::Opmask k_load_dim_tail_mask_extended = Xbyak::Opmask(5); const Xbyak::Opmask postops_mask = Xbyak::Opmask(6); + + const Xbyak::Reg64 reg_d_weights = aux_reg_bcast_data; + const Xbyak::Reg64 reg_d_bias = reduce_loop_iter; + const Xbyak::Reg64 reg_oc_off = aux_reg_load_data; + + Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + const Xbyak::Opmask vmask = k7; const Vmm vmm_tmp = Vmm(28); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp index 538a89cf5ee..12e3f326e24 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp @@ -298,6 +298,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = static_cast(p.output_data) - dst_off * dst_dt_size; + p.oc_off = _ocb * jcp.oc_block * sizeof(float); (*kernel_)(&p); }; @@ -423,6 +424,7 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( par_conv_dw.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ocb * jcp_dw->ch_block * sizeof(float); (*kernel_dw_)(&par_conv_dw); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp index 21b5e19fce0..582c91c7033 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp @@ -103,6 +103,8 @@ struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public primitive_t { VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_CONV(zero_points_ok(), VERBOSE_UNSUPPORTED_ZP_CFG); + VDISPATCH_CONV(!this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); const convolution_desc_t *conv_d = desc(); const memory_desc_t *src_d = src_md(); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp index a808e692752..918ac3197b4 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -60,7 +60,7 @@ _jit_avx512_core_x8s8s32x_fwd_kernel::_jit_avx512_core_x8s8s32x_fwd_kernel( , jcp(ajcp) , attr_(attr) , postops_injector_(nullptr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -82,9 +82,18 @@ _jit_avx512_core_x8s8s32x_fwd_kernel::_jit_avx512_core_x8s8s32x_fwd_kernel( const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params; + int max_ur_w = nstl::max(jcp.ur_w, jcp.ur_w_tail); + int nb_oc_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + int last_accum_idx = vmm_out(max_ur_w - 1, nb_oc_block - 1).getIdx(); + if (last_accum_idx >= 30) + quantization_static_params = {zmm_d_weights.getIdx(), zmm_d_weights.getIdx(), reg_d_weights, reg_d_bias}; + else + quantization_static_params = {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, static_params, quantization_static_params); } if (!isa_has_bf16(jcp.isa) && jcp.dst_dt == data_type::bf16) bf16_emu_ = utils::make_unique(this, @@ -204,9 +213,18 @@ template void _jit_avx512_core_x8s8s32x_fwd_kernel::apply_postops(int ur_w, bool last_oc_block_flag, const int nb_oc_block, const int oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(nb_oc_block, ur_w, + [&](const bool, const int k, const int j) { + vmm_idx_off.insert({vmm_out_idx(j, k), k * oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {zmm_d_weights.getIdx(), zmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + apply_sum(ur_w, last_oc_block_flag, nb_oc_block, oc_block, p_sum_scale, - p_sum_zp); + p_sum_zp); injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { @@ -230,13 +248,13 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel::apply_postops(int ur_w, rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(nb_oc_block, ur_w, [&](const bool, const int k, const int j) { vmm_idxs.emplace(vmm_out_idx(j, k)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } @@ -977,6 +995,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel::generate() { int out_shift = jcp.typesize_out * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); preamble(); + bool with_quantization = attr_.post_ops_.find(primitive_kind::quantization) != -1; if (jcp.is_depthwise) { bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point; @@ -988,6 +1007,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel::generate() { // due to extra register used for shifts and compensations // and/or saturation, we increment by one more if (jcp.signed_input || jcp.need_saturation) ++idx; + if (with_quantization) ++idx; assert(IMPLICATION(!is_zero_point && jcp.dst_dt != data_type::bf16, idx == ker_dw_reg_base_idx)); @@ -1500,11 +1520,14 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw && jcp.kw < 4 && jcp.dilate_w == 0; + jcp.with_quantization = attr.post_ops_.find(primitive_kind::quantization) != -1; + if (jcp.is_depthwise) { jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise - jcp.signed_input - (!jcp.has_vnni) - (jcp.signed_input || jcp.need_saturation) // both alias - - (bf16_req_extra_regs ? 4 : 0); + - (bf16_req_extra_regs ? 4 : 0) + - jcp.with_quantization; } else { jcp.max_regs_ur = bf16_req_extra_regs ? 26 : jcp.has_vnni ? 31 : 28; } @@ -1602,7 +1625,10 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, const int sum_ind = post_ops.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; - jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; jcp.post_ops = post_ops; @@ -1611,7 +1637,7 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_requires_scale_one = false; static constexpr bool sum_requires_zp_zero = false; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(avx512_core, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); if (!post_ops_ok_) return status::unimplemented; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp index 2a0333317f3..ce75d01cc2f 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp @@ -50,7 +50,7 @@ struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { typename utils::conditional::value, Xbyak::Ymm, Xbyak::Xmm>::type; const int ic_sub_step = 4; - std::unique_ptr> + std::unique_ptr> postops_injector_; enum { @@ -105,6 +105,12 @@ struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { /* binary post-op operand */ const Xbyak::Reg64 temp_offset_reg = r12; + const Xbyak::Reg64 reg_d_weights = r15; + const Xbyak::Reg64 reg_d_bias = r13; + + const Xbyak::Zmm zmm_d_weights = Xbyak::Zmm(31); + const Xbyak::Zmm zmm_d_bias = Xbyak::Zmm(30); + const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3); const Xbyak::Opmask postops_mask = Xbyak::Opmask(4); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp index 1242fd8e3f8..955cc42c4af 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp @@ -165,6 +165,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); + (*kernel_)(&p); ++start; @@ -333,6 +335,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); + (*kernel_)(&p); src_w += src_h_stride * jcp.stride_h; @@ -470,6 +474,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g * sizeof(float); (*kernel_)(&p); }); @@ -640,6 +645,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); + (*kernel_)(&p); src_w += src_h_stride * jcp.stride_h; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp index 78dad1c5880..305e81211f9 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp @@ -77,6 +77,8 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t { VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_CONV(zero_points_ok(), VERBOSE_UNSUPPORTED_ZP_CFG); + VDISPATCH_CONV(!this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); CHECK(jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp index c02bc84bdfe..f65a6ab92ac 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -48,7 +48,7 @@ jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: , attr_(attr) , postops_injector_(nullptr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { const std::size_t tail_size = jcp.is_depthwise ? jcp.ngroups % jcp.ch_block : jcp.oc_without_padding % jcp.oc_block; @@ -65,9 +65,19 @@ jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp {this->param1, rhs_sp}; + if (jcp.has_vnni) { + vmm_d_weights = Vmm(28); + vmm_d_bias = Vmm(29); + } else { + vmm_d_weights = Vmm(26); + vmm_d_bias = Vmm(27); + } + + const quantization_injector::static_params_t qsp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< - injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, bsp); + injector::jit_uni_postops_injector_t>( + this, jcp.post_ops, bsp, qsp); } } @@ -284,6 +294,11 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( const int prelu_ind = p.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + const int depthwise_ind = p.find(primitive_kind::depthwise); + jcp.with_depthwise = depthwise_ind != -1; + const int quantization_ind = p.find(primitive_kind::quantization); + jcp.with_quantization = quantization_ind != -1; + const int sum_ind = p.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; @@ -309,7 +324,12 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( jcp.nb_ic = jcp.ic / jcp.ic_block; /* kernel blocking params */ - const int regs = jcp.has_vnni ? 30 : 28; + int max_regs = jcp.has_vnni ? 30 : 28; + if (jcp.with_depthwise || jcp.with_quantization) { + max_regs -= 2; + } + + const int regs = max_regs; jcp.nb_ch_blocking = 1; jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) @@ -379,7 +399,7 @@ bool _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( static constexpr bool sum_requires_scale_one = false; return injector::post_ops_ok( - post_ops_ok_args_t(avx512_core, {eltwise, binary, sum}, post_ops, + post_ops_ok_args_t(avx512_core, {eltwise, binary, sum, depthwise, quantization}, post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one)); } @@ -986,7 +1006,7 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( } } /* Do post-ops */ - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { const auto &p = attr_.post_ops_; const int sum_idx = p.find(primitive_kind::sum); const float *p_sum_scale @@ -1042,10 +1062,21 @@ void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( } } } + + std::map vmm_idx_off; + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + vmm_idx_off.insert({vmm_out(ur, ocb).getIdx(), ocb * jcp.oc_block * sizeof(float)}); + } + } + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + const int nb_oc_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; postops_injector_->compute_vector_range( - 0, nb_oc_block * ur_w, rhs_arg_params); + 0, nb_oc_block * ur_w, rhs_arg_params, ddp, qdp); } if (jcp.dst_scale) { @@ -1489,6 +1520,7 @@ status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d( p.kh_padding = jcp.kh; p.oc_blocks = jcp.is_depthwise ? g : ocb; p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + p.oc_off = g_oc * sizeof(float); p.zp_compensation = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; p.zp_src_pad_str_compensation @@ -1654,6 +1686,7 @@ status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d( p.oc_blocks = jcp.is_depthwise ? g : ocb; p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + p.oc_off = g_oc * sizeof(float); p.zp_compensation = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; p.zp_src_pad_str_compensation = jcp.src_zero_point @@ -1877,6 +1910,7 @@ status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_3d( p.oc_blocks = jcp.is_depthwise ? g : ocb; p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); + p.oc_off = g_oc * sizeof(float); p.zp_compensation = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; p.zp_src_pad_str_compensation = jcp.src_zero_point diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp index 5c0475f637b..d882d4523f2 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp @@ -73,7 +73,7 @@ struct ur_w_blks_params_t { template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_kernel); jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md); @@ -83,7 +83,7 @@ struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { const primitive_attr_t &attr_; private: - std::unique_ptr> + std::unique_ptr> postops_injector_; const int ic_sub_step = 4; @@ -138,6 +138,12 @@ struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { const Vmm vmm_dst_scale = Vmm(31); const Vmm vmm_prev_dst = Vmm(31); + /* depthwise and quantization post ops */ + const Xbyak::Reg64 reg_d_weights = r15; + const Xbyak::Reg64 reg_d_bias = r13; + Vmm vmm_d_weights; + Vmm vmm_d_bias; + Vmm vmm_out(int i_ur, int i_oc) { int idx = i_ur * jcp.nb_oc_blocking + i_oc; assert(idx < 31); diff --git a/src/cpu/x64/jit_gemm_inner_product_utils.cpp b/src/cpu/x64/jit_gemm_inner_product_utils.cpp index 52854681668..8cedf2d3842 100644 --- a/src/cpu/x64/jit_gemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_gemm_inner_product_utils.cpp @@ -1180,7 +1180,7 @@ void jit_pp_kernel_t::generate() { && (this->OC_ <= vlen / 2) && (this->MB_ >= vlen); bool supported_postops = this->do_scale_ || this->do_eltwise_ || this->do_binary_ || this->do_prelu_ || this->do_sum_ - || this->do_dst_zero_points_ || this->do_dst_scale_; + || this->do_dst_zero_points_ || this->do_dst_scale_ || (this->post_ops_.len() > 0); if (this->do_bias() && !supported_postops && dim_restrict && this->has_trivial_mb_stride()) { this->mb_blk_kernel_ = true; diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp index b9b1773db3c..b46888ad761 100644 --- a/src/cpu/x64/jit_primitive_conf.hpp +++ b/src/cpu/x64/jit_primitive_conf.hpp @@ -102,6 +102,8 @@ struct jit_conv_conf_t { bool with_sum; bool with_eltwise; bool with_binary; + bool with_depthwise; + bool with_quantization; data_type_t sum_dt; @@ -346,6 +348,9 @@ struct jit_conv_call_s { int oc_flag; size_t last_ic_block; size_t last_oc_block; + + size_t oc_off; + size_t oc_off_prf; }; struct jit_deconv_call_s { @@ -374,6 +379,7 @@ struct jit_deconv_call_s { size_t kh_padding; size_t kd_padding; size_t oc_blocks; + size_t oc_off; }; struct jit_dw_conv_call_s { @@ -405,6 +411,8 @@ struct jit_1x1_conv_conf_t { bool with_sum; bool with_eltwise; bool with_binary; + bool with_depthwise; + bool with_quantization; bool with_dw_conv; post_ops_t post_ops; @@ -483,6 +491,8 @@ struct jit_1x1_conv_call_s { size_t output_stride; // used in backward_weights only size_t first_last_flag; + + size_t oc_off; }; struct jit_pool_conf_t { @@ -524,6 +534,8 @@ struct jit_pool_conf_t { bool with_postops; bool with_eltwise; bool with_binary; + bool with_depthwise; + bool with_quantization; int nthr; memory_desc_t tmp_md; }; diff --git a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp index bfcad42bfb5..cee4646667b 100644 --- a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp @@ -43,7 +43,7 @@ jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), sse41), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; static constexpr size_t helper_vmm_idx = 15; @@ -57,9 +57,12 @@ jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32( use_exact_tail_scalar_bcast}; const binary_injector::static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -130,6 +133,15 @@ static void iterate(const int load_loop_blk, const int ur, const F &f) { } void jit_sse41_1x1_conv_kernel_f32::apply_postops( const int load_loop_blk, const int ur) { + std::map vmm_idx_off; + iterate(load_loop_blk, ur, + [&](const int i, const int j, const int n) { + vmm_idx_off.insert({reg_accum_idx(load_loop_blk, i, j, n), (2 * i + n) * jcp.load_block / 2 * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -158,12 +170,12 @@ void jit_sse41_1x1_conv_kernel_f32::apply_postops( mov(abi_param1, ptr[rsp + reg_abi_param1_backup + reg_guard_stack_occupied]); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(load_loop_blk, ur, [&](const int i, const int j, const int n) { vmm_idxs.emplace(reg_accum_idx(load_loop_blk, i, j, n)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -286,7 +298,7 @@ void jit_sse41_1x1_conv_kernel_f32::generate_reduce_loop( L(store_noadd); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { assert(ur * load_loop_blk < 14); Label store_nopostops; @@ -460,6 +472,7 @@ void jit_sse41_1x1_conv_kernel_f32::generate() { mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); if (jcp.prop_kind == backward_weights) mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto generate_load_loop_body = [&](int load_loop_blk) { const size_t offst_with_dw_conv @@ -492,6 +505,7 @@ void jit_sse41_1x1_conv_kernel_f32::generate() { default: assert(!"invalid prop_kind"); } sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; Label load_loop_blk_8; @@ -607,6 +621,9 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, const int prelu_ind = post_ops.find(primitive_kind::prelu, 0, dw_conv_ind); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; + if (dw_conv_ind >= 0) { // dw_conv and post_ops after it are handled externally, so skip them jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), @@ -619,8 +636,9 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; + const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(sse41, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp index d8061777cdc..8c2350f7f09 100644 --- a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp @@ -58,12 +58,12 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator { reg64_t reg_output_data = rbx; reg64_t aux_reg_bcast_data = rdx; reg64_t aux1_reg_bcast_data = abi_not_param1; - reg64_t aux_reg_load_data = abi_param1; reg64_t aux_reg_output_data = rbp; reg64_t reg_load_loop_work = r9; reg64_t reg_bcast_loop_work = r10; reg64_t reg_reduce_loop_work = r11; reg64_t load_loop_iter = r13; + reg64_t aux_reg_load_data = load_loop_iter; reg64_t imm_addr64 = load_loop_iter; reg64_t bcast_loop_iter = r14; reg64_t reduce_loop_iter = r15; @@ -83,6 +83,13 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator { std::unique_ptr> postops_injector_; + reg64_t reg_oc_off = abi_param1; + reg64_t reg_d_weights = aux_reg_bcast_data; + reg64_t reg_d_bias = reduce_loop_iter; + + Xbyak::Xmm xmm_d_weights = Xbyak::Xmm(14); + Xbyak::Xmm xmm_d_bias = Xbyak::Xmm(15); + void generate_bcast_loop(int load_loop_blk); void generate_reduce_loop(int load_loop_blk, int ur); void generate_diff_bias_loop(int load_loop_blk); diff --git a/src/cpu/x64/jit_sse41_1x1_convolution.cpp b/src/cpu/x64/jit_sse41_1x1_convolution.cpp index 52b68bbfe9a..ddd50624202 100644 --- a/src/cpu/x64/jit_sse41_1x1_convolution.cpp +++ b/src/cpu/x64/jit_sse41_1x1_convolution.cpp @@ -182,6 +182,7 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; par_conv.dst_orig = static_cast(par_conv.output_data) - dst_off; + par_conv.oc_off = _ocb * jcp.oc_block * sizeof(float); (*kernel_)(&par_conv); }; @@ -261,6 +262,8 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward_thr(const int ithr, = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + par_conv_dw.oc_off = ch * jcp_dw.ch_block * sizeof(float); + (*kernel_dw_)(&par_conv_dw); for (int i = 0; i < jcp_dw.kh; ++i) diff --git a/src/cpu/x64/jit_sse41_1x1_convolution.hpp b/src/cpu/x64/jit_sse41_1x1_convolution.hpp index fedc8560c87..b5650ad4e3c 100644 --- a/src/cpu/x64/jit_sse41_1x1_convolution.hpp +++ b/src/cpu/x64/jit_sse41_1x1_convolution.hpp @@ -67,6 +67,9 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t { VDISPATCH_CONV( attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); CHECK(jit_sse41_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), *attr(), @@ -274,7 +277,7 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t { if (pd()->jcp_.with_dw_conv) { CHECK(safe_ptr_assign(kernel_dw_, new dw_conv_kernel_t( - pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0)))); + pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0), *pd()->dw_conv_pd_->attr()))); return kernel_dw_->create_kernel(); } diff --git a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp index f9421477224..be82728fae2 100644 --- a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp @@ -43,7 +43,7 @@ jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), sse41), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; static constexpr size_t helper_vmm_idx = 15; @@ -57,10 +57,12 @@ jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32( use_exact_tail_scalar_bcast}; const binary_injector::static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -158,6 +160,14 @@ static void iterate(const int oc_blocks, const int ur_w, const F &f) { } void jit_sse41_conv_fwd_kernel_f32::apply_postops( const int oc_blocks, const int ur_w) { + std::map vmm_idx_off; + iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { + vmm_idx_off.insert({get_xmm_idx(ur_w, i, j), i * jcp.oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {xmm_d_weights.getIdx(), xmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; @@ -175,12 +185,12 @@ void jit_sse41_conv_fwd_kernel_f32::apply_postops( rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } else { iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) { vmm_idxs.emplace(get_xmm_idx(ur_w, i, j)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } @@ -266,7 +276,7 @@ void jit_sse41_conv_fwd_kernel_f32::width_blk_step( L(skip_kh_loop); - if (jcp.with_eltwise || jcp.with_binary) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { Label regular_store; test(reg_ci_flag, FLAG_IC_LAST); je(regular_store, T_NEAR); @@ -288,12 +298,15 @@ void jit_sse41_conv_fwd_kernel_f32::width_blk_step( add(aux_reg_kernel, sizeof(float) * 4); add(reg_output, sizeof(float) * 4); add(reg_bias, sizeof(float) * 4); + add(reg_oc_off, sizeof(float) * 4); + inc(simd_iter); cmp(simd_iter, 2); jl(init_simd_iter_loop, T_NEAR); sub(reg_output, sizeof(float) * 8); sub(reg_bias, sizeof(float) * 8); + sub(reg_oc_off, sizeof(float) * 8); } inline void jit_sse41_conv_fwd_kernel_f32::solve_common(int oc_blocks) { @@ -355,6 +368,7 @@ void jit_sse41_conv_fwd_kernel_f32::generate() { mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; Label tail, exit; @@ -458,6 +472,8 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, const int binary_ind = post_ops.find(primitive_kind::binary); const int prelu_ind = post_ops.find(primitive_kind::prelu); jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -465,8 +481,9 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; static constexpr bool sum_requires_zp_zero = true; + const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(sse41, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, sum_at_pos_0_only, + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp b/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp index 42f87bcceae..ee041776c4e 100644 --- a/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_sse41_conv_kernel_f32.hpp @@ -70,6 +70,13 @@ struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator { std::unique_ptr> postops_injector_; + reg64_t reg_d_weights = imm_addr64; + reg64_t reg_d_bias = ki_iter; + reg64_t reg_oc_off = abi_param1; + + Xbyak::Xmm xmm_d_weights = Xbyak::Xmm(14); + Xbyak::Xmm xmm_d_bias = Xbyak::Xmm(15); + inline void oh_step_unroll_kw( int ur_w, int pad_l, int pad_r, int oc_blocks); inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); diff --git a/src/cpu/x64/jit_sse41_convolution.cpp b/src/cpu/x64/jit_sse41_convolution.cpp index af0523262de..9a123a2317a 100644 --- a/src/cpu/x64/jit_sse41_convolution.cpp +++ b/src/cpu/x64/jit_sse41_convolution.cpp @@ -126,7 +126,7 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.flags |= FLAG_IC_FIRST; } - if ((jcp.with_eltwise || jcp.with_binary) + if ((jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) && icb + 1 == jcp.nb_ic) { par_conv.flags |= FLAG_IC_LAST; } @@ -143,6 +143,7 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); par_conv.dst_orig = dst; + par_conv.oc_off = _oc * (is_dst_layout_nxc ? 1 : jcp.oc_block) * sizeof(float); (*kernel_)(&par_conv); } diff --git a/src/cpu/x64/jit_sse41_convolution.hpp b/src/cpu/x64/jit_sse41_convolution.hpp index 6a51f32edf3..1bf4734c113 100644 --- a/src/cpu/x64/jit_sse41_convolution.hpp +++ b/src/cpu/x64/jit_sse41_convolution.hpp @@ -57,6 +57,9 @@ struct jit_sse41_convolution_fwd_t : public primitive_t { VDISPATCH_CONV( attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); CHECK(jit_sse41_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), *attr(), diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp index 290cba2d4e2..388ccf45ea0 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp @@ -37,9 +37,9 @@ using namespace Xbyak; template jit_uni_dw_conv_fwd_kernel_f32::jit_uni_dw_conv_fwd_kernel_f32( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) - : jit_generator(jit_name(), isa), jcp(ajcp) { - if (jcp.with_eltwise || jcp.with_binary) { + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr) + : jit_generator(jit_name(), isa), jcp(ajcp), attr_(attr) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = false; @@ -53,10 +53,12 @@ jit_uni_dw_conv_fwd_kernel_f32::jit_uni_dw_conv_fwd_kernel_f32( memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -272,8 +274,19 @@ void iterate( template void jit_uni_dw_conv_fwd_kernel_f32::apply_postops( const int ur_ch_blocks, const int ur_w, const bool is_ch_tail) { - if (this->jcp.with_eltwise || this->jcp.with_binary) { + if (this->jcp.with_eltwise || this->jcp.with_binary || this->jcp.with_depthwise || this->jcp.with_quantization) { const int repeats = max_repeats(); + + std::map vmm_idx_off; + iterate(repeats, ur_ch_blocks, ur_w, + [&](const int r, const int ch, const int ow, const bool) { + vmm_idx_off.insert({get_acc_reg_idx(r * ur_ch_blocks * ur_w + ch * ur_w + ow), (ch * repeats + r) * jcp.ch_block / repeats * sizeof(float)}); + }); + + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + injector_utils::vmm_index_set_t vmm_idxs; if (jcp.with_binary) { binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, @@ -321,16 +334,16 @@ void jit_uni_dw_conv_fwd_kernel_f32::apply_postops( cmp(reg_tmp, jcp.nb_ch_blocking * jcp.ch_block); jge(postops_no_tail, T_NEAR); postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, rhs_arg_params_tail, ddp, qdp); jmp(postops_done, T_NEAR); L(postops_no_tail); } else if (is_ch_tail) { postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, rhs_arg_params_tail, ddp, qdp); } if (!is_ch_tail) { postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params); + vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } } else { @@ -339,7 +352,7 @@ void jit_uni_dw_conv_fwd_kernel_f32::apply_postops( vmm_idxs.emplace(get_acc_reg_idx( r * ur_ch_blocks * ur_w + ch * ur_w + ow)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t(), ddp, qdp); } } } diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp index 1a152cec9e7..32345429f8d 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.hpp @@ -36,9 +36,10 @@ struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) jit_uni_dw_conv_fwd_kernel_f32( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md); + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr); jit_conv_conf_t jcp; + const primitive_attr_t &attr_; private: using Vmm = typename utils::conditional3::init_conf( broadcasting_strategy_t::per_oc, broadcasting_strategy_t::no_broadcast); } + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; @@ -231,7 +233,7 @@ status_t jit_uni_dw_conv_fwd_kernel::init_conf( static constexpr bool sum_at_pos_0_only = true; static constexpr bool sum_requires_scale_one = true; const bool post_ops_ok_ = post_ops_ok( - post_ops_ok_args_t(isa, {eltwise, binary, sum}, jcp.post_ops, + post_ops_ok_args_t(isa, {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp b/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp index 7f5902e055b..df7a57914a1 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_utils.hpp @@ -39,8 +39,8 @@ template struct jit_uni_dw_conv_fwd_kernel { jit_uni_dw_conv_fwd_kernel( - const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) - : ker_(utils::make_unique(ajcp, dst_md)) {} + const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md, const primitive_attr_t &attr) + : ker_(utils::make_unique(ajcp, dst_md, attr)) {} status_t create_kernel() { if (ker_) return ker_->create_kernel(); diff --git a/src/cpu/x64/jit_uni_dw_convolution.cpp b/src/cpu/x64/jit_uni_dw_convolution.cpp index 700e122805b..28d0f53a7ed 100644 --- a/src/cpu/x64/jit_uni_dw_convolution.cpp +++ b/src/cpu/x64/jit_uni_dw_convolution.cpp @@ -143,6 +143,8 @@ void jit_uni_dw_convolution_fwd_t::execute_forward( par_conv.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); par_conv.dst_orig = dst; + par_conv.oc_off = ch * jcp.ch_block * sizeof(float); + (*kernel_)(&par_conv); if (jcp.loop_order == loop_ngcw) { diff --git a/src/cpu/x64/jit_uni_dw_convolution.hpp b/src/cpu/x64/jit_uni_dw_convolution.hpp index 83a1f1d6ec5..4a0eccfd8ec 100644 --- a/src/cpu/x64/jit_uni_dw_convolution.hpp +++ b/src/cpu/x64/jit_uni_dw_convolution.hpp @@ -60,6 +60,9 @@ struct jit_uni_dw_convolution_fwd_t : public primitive_t { utils::one_of(this->desc()->bias_desc.data_type, f32, bf16)), VERBOSE_UNSUPPORTED_BIAS_CFG); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); auto status = jit_uni_dw_conv_fwd_kernel::init_conf( jcp_, *desc(), src_md_, weights_md_, bias_md_, dst_md_, @@ -86,7 +89,7 @@ struct jit_uni_dw_convolution_fwd_t : public primitive_t { status_t init(engine_t *engine) override { CHECK(safe_ptr_assign(kernel_, new jit_uni_dw_conv_fwd_kernel( - pd()->jcp_, *pd()->dst_md(0)))); + pd()->jcp_, *pd()->dst_md(0), *pd()->attr()))); return kernel_->create_kernel(); } diff --git a/src/cpu/x64/jit_uni_i8i8_pooling.cpp b/src/cpu/x64/jit_uni_i8i8_pooling.cpp index 5d60ff176e5..fe15756927b 100644 --- a/src/cpu/x64/jit_uni_i8i8_pooling.cpp +++ b/src/cpu/x64/jit_uni_i8i8_pooling.cpp @@ -97,12 +97,17 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { Reg64 aux_reg_src_h = rax; Reg64 aux_reg_src_w = rbx; + Reg64 reg_store_tmp = r11; // shared with reg_kh_index and used only as tmp register for store on avx2 Reg64 reg_tmp = rdx; // only used during mask init and store Reg64 reg_src_safe_access = rbp; Reg64 reg_dst_safe_access = rsi; Reg64 reg_mask = r15; // only used during mask init + Reg64 reg_oc_off = reg_tmp; + Reg64 reg_d_weights = aux_reg_src_h; + Reg64 reg_d_bias = aux_reg_src_w; + Opmask k_cmp_mask = Opmask(7); Opmask mask(int idx) { return Opmask(6 - idx); } @@ -139,6 +144,9 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { std::unique_ptr> postops_injector_; + Vmm vmm_d_weights = vreg(3); + Vmm vmm_d_bias = vreg(4); + enum : int { max_vidx_base = utils::one_of(isa, sse41, avx2) ? 7 : 2 }; //"avg" pool uses more registers for unrolling. enum : int { avg_vidx_base = utils::one_of(isa, sse41, avx2) ? 4 : 2 }; @@ -263,10 +271,12 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp { reg_param, get_supported_bcast_strategies(), rhs_sp}; + quantization_injector::static_params_t qsp = + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jpp.post_ops, bsp); + this, jpp.post_ops, bsp, qsp); } } }; @@ -659,18 +669,18 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( // Don't generate useless code if (masked && !msk) return; - const Vmm &vr_dst = vreg_dst_s32(jj, ll); + const Vmm &vr_dst = jpp.dst_dt == f32 ? vreg_dst_f32(jj, ll) : vreg_dst_s32(jj, ll); - if (jpp.src_dt == s32) { + if (jpp.dst_dt == s32 || jpp.dst_dt == f32) { if (masked) for (int i = 0; i < jpp.c_tail; i++) pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)], vr_dst, i); else movups(ptr[reg_ptr_dst_i8 + offset], vr_dst); - } else if (utils::one_of(jpp.src_dt, s8, u8)) { + } else if (utils::one_of(jpp.dst_dt, s8, u8)) { packssdw(vr_dst, vr_dst); - if (jpp.src_dt == s8) + if (jpp.dst_dt == s8) packsswb(vr_dst, vr_dst); else packuswb(vr_dst, vr_dst); @@ -729,8 +739,8 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( // maskmovdqu/vmaskmovdqu // with low 8-bytes mask throws exception if high 8-bytes belongs write-protected page. // NOTE: use indirect move via gpr to avoid transition penalty - vmovq(reg_tmp, Xmm(vr_dst.getIdx())); - movq(mmx_dst_i8, reg_tmp); + vmovq(reg_store_tmp, Xmm(vr_dst.getIdx())); + movq(mmx_dst_i8, reg_store_tmp); // mmx_full_msk - mask for all 8 bytes in zero-tail case // mmx_mask(ll) - ll-th mask of tail in non-zero-tail case @@ -771,6 +781,17 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( }; switch (jpp.dst_dt) { + case f32: + if (masked) { + if (sizeof_src_dt() != sizeof_dst_dt()) { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask_2, vreg_dst_f32(jj, ll)); + } else { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_f32(jj, ll)); + } + } else { + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_f32(jj, ll)); + } + break; case s32: if (masked) { vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, @@ -792,11 +813,11 @@ void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op( // Don't generate useless code if (masked && !msk) return; - const Vmm &vr_dst - = masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll); + const Vmm &vr_dst = jpp.dst_dt == f32 ? masked ? vreg_dst_f32(jj, ll) | mask(ll) : vreg_dst_f32(jj, ll) + : masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll); switch (jpp.dst_dt) { - case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; + case f32: case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; case s8: vpmovsdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; case u8: vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; default: assert(!"unsupported dst data_type"); @@ -936,7 +957,7 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( int iw = jpp.iw; int c = jpp.c; - const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.src_dt); + const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.dst_dt); for (int jj = 0; jj < ur_c; jj++) { for (int ll = 0; ll < num_ll; ll++) { @@ -950,6 +971,9 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( } } + if (jpp.with_depthwise || jpp.with_quantization) + push(reg_oc_off); + mov(aux_reg_src_d, reg_ptr_src_i8); xor_(reg_kd_index, reg_kd_index); L(l_kd); @@ -989,6 +1013,11 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( jl(l_kd, T_NEAR); } + static constexpr int vlen_size_elem = cpu_isa_traits::vlen / sizeof(float); + + if (jpp.with_depthwise || jpp.with_quantization) + pop(reg_oc_off); + for (int jj = 0; jj < ur_c; jj++) { for (int ll = 0; ll < num_ll; ll++) { const bool masked = jj == ur_c - 1 && c_tail; @@ -1000,6 +1029,15 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( uni_vfmadd132ps(reg_dst_f32, vreg_zeros, vreg_tmp); if (jpp.with_postops) { + std::map vmm_idx_off; + vmm_idx_off.insert({reg_dst_f32.getIdx(), (ll * vlen_size_elem + jj * vlen_size_elem) * sizeof(float)}); + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + + injector_utils::vmm_index_set_t vmm_idxs; + vmm_idxs.emplace(reg_dst_f32.getIdx()); + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; if (jpp.with_binary) { rhs_arg_params.vmm_idx_to_out_reg.emplace( @@ -1011,16 +1049,19 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step( rhs_arg_params.vmm_tail_idx_.emplace( reg_dst_f32.getIdx()); } - postops_injector_->compute_vector( - reg_dst_f32.getIdx(), rhs_arg_params); + postops_injector_->compute_vector_range( + vmm_idxs, rhs_arg_params, ddp, qdp); } - uni_vcvtps2dq(reg_dst_s32, reg_dst_f32); + if (jpp.dst_dt != f32) { + uni_vcvtps2dq(reg_dst_s32, reg_dst_f32); + } if (jpp.with_postops) if (jpp.dst_dt == u8) { uni_vpmaxsd(reg_dst_s32, reg_dst_s32, vreg_zeros); } + store_dst(jj, ll, c_tail); } } @@ -1049,12 +1090,17 @@ void jit_uni_i8i8_pooling_fwd_ker_t::compute_c_block() { int c_tail = jpp.c_tail; xor_(c_iter, c_iter); + if (jpp.with_quantization) + xor_(reg_oc_off, reg_oc_off); + if (c_steps > 0) { L(l_main_loop); { compute_step(ur_c, 0); add(reg_ptr_src_i8, ur_c * c_block * sizeof_src_dt()); add(reg_ptr_dst_i8, ur_c * c_block * sizeof_dst_dt()); + if (jpp.with_quantization) + add(reg_oc_off, ur_c*c_block*sizeof(float)); inc(c_iter); cmp(c_iter, c_steps); jl(l_main_loop, T_NEAR); @@ -1130,6 +1176,10 @@ void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { vpalignr(vreg_mask_2, vreg_mask_2, vreg_zeros, 32 - shift); } vextracti128(xreg_mask_2_hi, vreg_mask_2, 0x1); + + if (sizeof_src_dt() != sizeof_dst_dt()) { + vpmovsxbd(vreg_mask_2, vreg_mask); + } } // Need mask in MMX regs ? @@ -1322,7 +1372,7 @@ status_t jit_uni_i8i8_pooling_fwd_ker_t::init_conf( // isa == sse41 : 16 bytes -> 16 for s8/u8, 4 for s32 // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 - int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.src_dt); + int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.dst_dt); /* Verify that vlen-sized memory access happens within the tensor's * size, otherwise load/store will always spill outside the memory @@ -1385,6 +1435,8 @@ bool jit_uni_i8i8_pooling_fwd_ker_t::post_ops_ok(jit_pool_conf_t &jpp, jpp.with_postops = false; jpp.with_eltwise = false; jpp.with_binary = false; + jpp.with_depthwise = false; + jpp.with_quantization = false; if (entries.empty()) return true; @@ -1398,11 +1450,16 @@ bool jit_uni_i8i8_pooling_fwd_ker_t::post_ops_ok(jit_pool_conf_t &jpp, && entry.binary.src1_desc.data_type == data_type::bf16) return false; jpp.with_binary = true; - } else + } else if (entry.is_depthwise()) { + jpp.with_depthwise = true; + } else if (entry.is_quantization()) { + jpp.with_quantization = true; + } else { return false; + } } - jpp.with_postops = jpp.with_eltwise || jpp.with_binary; + jpp.with_postops = jpp.with_eltwise || jpp.with_binary || jpp.with_depthwise || jpp.with_quantization; jpp.post_ops = post_ops; /* diff --git a/src/cpu/x64/jit_uni_i8i8_pooling.hpp b/src/cpu/x64/jit_uni_i8i8_pooling.hpp index 78708c32d18..fbf5ec09eec 100644 --- a/src/cpu/x64/jit_uni_i8i8_pooling.hpp +++ b/src/cpu/x64/jit_uni_i8i8_pooling.hpp @@ -79,6 +79,10 @@ struct jit_uni_i8i8_pooling_fwd_t : public primitive_t { VDISPATCH_POOLING( attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_POOLING(IMPLICATION( + utils::one_of(desc()->alg_kind, alg_kind::pooling_avg_include_padding, alg_kind::pooling_avg_exclude_padding), + utils::one_of(dst_md()->data_type, data_type::u8, data_type::s8, data_type::f32)), + VERBOSE_BAD_ALGORITHM); CHECK(jit_conf()); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp index 7976522e6c6..c6a5bc93618 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp @@ -49,7 +49,7 @@ _jit_uni_x8s8s32x_1x1_conv_kernel::_jit_uni_x8s8s32x_1x1_conv_kernel( const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), isa), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = true; @@ -61,10 +61,12 @@ _jit_uni_x8s8s32x_1x1_conv_kernel::_jit_uni_x8s8s32x_1x1_conv_kernel( memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; static_params_t static_params {this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -190,7 +192,15 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::apply_postops(const int ur, const int load_loop_blk, const bool mask_flag_in, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) { + vmm_idx_off.insert({vreg_accum_idx(load_loop_blk, i_load, i_ur), i_load * jcp.load_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + reg_oc_off, vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {reg_oc_off, vmm_idx_off}; + if (jcp.with_sum && *p_sum_zp != 0) mov(ptr[rsp + reg_bcast_loop_iter_off], reg_ptr_sum_zp); apply_sum(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp); @@ -223,17 +233,17 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::apply_postops(const int ur, test(reg_reduce_pos_flag, FLAG_OC_LAST); je(postops_no_tail, T_NEAR); postops_injector_->compute_vector_range( - vmm_idxs, rhs_arg_params_tail); + vmm_idxs, rhs_arg_params_tail, ddp, qdp); jmp(postops_done, T_NEAR); L(postops_no_tail); } - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); L(postops_done); } else { iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) { vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i_load, i_ur)); }); - postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } if (jcp.with_sum && *p_sum_zp != 0) mov(reg_ptr_sum_zp, ptr[rsp + reg_bcast_loop_iter_off]); @@ -462,6 +472,8 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::reduce_loop( Label reduce_loop; Label reduce_loop_tail; + push(reg_oc_off); + mov(aux_reg_load_data, reg_load_data); mov(aux_reg_bcast_data, aux1_reg_bcast_data); @@ -483,6 +495,8 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::reduce_loop( L(reduce_loop_tail); fma_block(jcp.ic != jcp.ic_without_padding); + pop(reg_oc_off); + if (jcp.oc_without_padding != jcp.oc) { Label end_store, common_store; mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data); @@ -548,6 +562,7 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { mov(ptr[rsp + bcast_loop_work_off], reg_bcast_loop_work); mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); auto load_loop_body = [&](int load_loop_blk) { bcast_loop(load_loop_blk); @@ -581,6 +596,7 @@ void _jit_uni_x8s8s32x_1x1_conv_kernel::generate() { mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]); add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out); sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float)); }; static const int ur_cases[] = {2, 3, 5, 12}; @@ -695,6 +711,9 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); jcp.with_sum = sum_ind != -1; + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise, 0, dw_conv_ind) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization, 0, dw_conv_ind) != -1; + const auto zp = attr.zero_points_; jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); @@ -730,7 +749,7 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( using namespace injector; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(isa, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, false, false, false)); + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, false, false, false)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); args_ok = true && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 @@ -742,7 +761,8 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; jcp.dst_dt = cd.dst_desc.data_type; - jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); jcp.ic_block = jcp.oc_block = simd_w; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp index 204d00d0bc5..5d14f8f7b27 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp @@ -35,7 +35,6 @@ struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_1x1_conv_kernel) _jit_uni_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md); - int get_tail_size() { return jcp.oc_without_padding % jcp.oc_block; } jit_1x1_conv_conf_t jcp; @@ -63,7 +62,7 @@ struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { const Xbyak::Reg64 reg_reduce_loop_iter = r13; const Xbyak::Reg64 aux_reg_bcast_data = r14; const Xbyak::Reg64 aux_reg_load_data = r15; - const Xbyak::Reg64 aux_reg_saturation = r15; + const Xbyak::Reg64 aux_reg_saturation = r14; const Xbyak::Reg64 reg_reduce_pos_flag = rax; const Xbyak::Reg64 aux1_reg_bcast_data = rbx; const Xbyak::Reg64 reg_bcast_loop_work = rbx; @@ -75,6 +74,13 @@ struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { const Xbyak::Reg64 reg_src_zero_point = aux_reg_bcast_data; // r14 const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point; + const Xbyak::Reg64 reg_d_weights = aux_reg_bcast_data; + const Xbyak::Reg64 reg_d_bias = abi_param1; + const Xbyak::Reg64 reg_oc_off = aux_reg_load_data; + + Vmm vmm_d_weights = Vmm(0); + Vmm vmm_d_bias = Vmm(1); + const Vmm vmm_tmp = Vmm(3); const Vmm vmm_one = Vmm(2); const Vmm vmm_zero = Vmm(1); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp index 6b0e7e9d9e9..6485aa3bd81 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp @@ -291,6 +291,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; p.dst_orig = static_cast(p.output_data) - dst_offset * dst_dt_size; + p.oc_off = _ocb * jcp.oc_block * sizeof(float); (*kernel_)(&p); }; @@ -420,6 +421,7 @@ void jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr( par_conv_dw.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec_dw; par_conv_dw.dst_orig = dst; + p.oc_off = ocb * jcp_dw->ch_block * sizeof(float); (*kernel_dw_)(&par_conv_dw); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp index c86b8c75052..bc040e7e255 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp @@ -102,6 +102,9 @@ struct jit_uni_x8s8s32x_1x1_convolution_fwd_t : public primitive_t { VDISPATCH_CONV( attr_.set_default_formats(dst_md(0)) == status::success, VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); const convolution_desc_t *conv_d = desc(); const memory_desc_t *src_d = src_md(); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp index d4512765ee3..c4d9742cd87 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp @@ -58,11 +58,11 @@ _jit_uni_x8s8s32x_fwd_kernel::_jit_uni_x8s8s32x_fwd_kernel( const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, const memory_desc_t &dst_md) : jit_generator(jit_name(), isa), jcp(ajcp), attr_(attr) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { using namespace binary_injector; static constexpr bool preserve_gpr = true; - static constexpr bool preserve_vmm = false; - static constexpr size_t helper_vmm_idx = 15; + static constexpr bool preserve_vmm = true; + static constexpr size_t helper_vmm_idx = 2; const size_t block_tail = (jcp.is_depthwise ? jcp.ch_block : jcp.oc_block) % isa_simd_width_; @@ -77,10 +77,12 @@ _jit_uni_x8s8s32x_fwd_kernel::_jit_uni_x8s8s32x_fwd_kernel( memory_desc_wrapper(dst_md), tail_size, true}; const static_params_t static_params { this->param1, rhs_arg_static_params}; + quantization_injector::static_params_t quantization_static_params = + {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; postops_injector_ = utils::make_unique>( - this, jcp.post_ops, static_params); + this, jcp.post_ops, static_params, quantization_static_params); } } @@ -180,14 +182,23 @@ template void _jit_uni_x8s8s32x_fwd_kernel::apply_postops( const int nb_oc_block, const int ur_w, const bool last_oc_block_flag, const int oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) { - if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { + if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum || jcp.with_depthwise || jcp.with_quantization) { + std::map vmm_idx_off; + iterate(nb_oc_block, ur_w, + [&](const bool, const int k, const int j) { + vmm_idx_off.insert({vmm_out_idx(j, k), k * oc_block * sizeof(float)}); + }); + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + if (jcp.with_sum && *p_sum_zp != 0) push(reg_ptr_sum_zp); apply_sum(nb_oc_block, ur_w, last_oc_block_flag, oc_block, p_sum_scale, p_sum_zp); vmm_index_set_t vmm_idxs; + binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; if (jcp.with_binary) { - binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; const bool oc_blk_is_smaller_than_vmm = oc_block < isa_simd_width_; iterate(nb_oc_block, ur_w, last_oc_block_flag, oc_blk_is_smaller_than_vmm, @@ -214,7 +225,7 @@ void _jit_uni_x8s8s32x_fwd_kernel::apply_postops( [&](const bool, const int k, const int j) { vmm_idxs.emplace(vmm_out_idx(j, k)); }); - postops_injector_->compute_vector_range(vmm_idxs); + postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params, ddp, qdp); } if (jcp.with_sum && *p_sum_zp != 0) pop(reg_ptr_sum_zp); } @@ -1469,14 +1480,18 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, jcp.with_binary = !everyone_is(-1, binary_ind, prelu_ind); const int sum_ind = post_ops.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; - jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + if (jcp.with_sum) + jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); + + jcp.with_depthwise = post_ops.find(primitive_kind::depthwise) != -1; + jcp.with_quantization = post_ops.find(primitive_kind::quantization) != -1; jcp.post_ops = post_ops; using namespace injector; const bool post_ops_ok_ = post_ops_ok(post_ops_ok_args_t(isa, - {eltwise, binary, sum}, jcp.post_ops, &dst_d, false, false, false)); + {eltwise, binary, sum, depthwise, quantization}, jcp.post_ops, &dst_d, false, false, false)); VDISPATCH_CONV_IC(post_ops_ok_, VERBOSE_UNSUPPORTED_POSTOP); jcp.typesize_in = types::data_type_size(src_d.data_type()); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp index 64f96c2a38d..bca64ff1f28 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp @@ -98,6 +98,12 @@ struct _jit_uni_x8s8s32x_fwd_kernel : public jit_generator { /* binary post-ops operand */ const Xbyak::Reg64 temp_offset_reg = r12; + const Xbyak::Reg64 reg_d_weights = r15; + const Xbyak::Reg64 reg_d_bias = r13; + + const Vmm vmm_d_weights = Vmm(0); + const Vmm vmm_d_bias = Vmm(1); + const Vmm vmm_wei = Vmm(0); /* used during bias/comp/scale section of store_output */ const Vmm vmm_bias = Vmm(0); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp index bfa20f0d33e..7a0ed6ede61 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp @@ -199,6 +199,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_2d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); (*kernel_)(&p); src_w += src_h_stride * jcp.stride_h; @@ -331,6 +332,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_1d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); (*kernel_)(&p); @@ -470,6 +472,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g * sizeof(float); (*kernel_)(&p); }); @@ -640,6 +643,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t::execute_forward_3d( p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); (*kernel_)(&p); src_w += src_h_stride * jcp.stride_h; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp b/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp index 2b029c4526a..ecf690b596b 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_convolution.hpp @@ -75,6 +75,9 @@ struct jit_uni_x8s8s32x_convolution_fwd_t : public primitive_t { VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_CONV(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_CONV(zero_points_ok(), VERBOSE_UNSUPPORTED_ZP_CFG); + VDISPATCH_CONV( + !this->attr()->has_asymmetric_quantization(), + VERBOSE_UNSUPPORTED_ATTR); CHECK(jit_uni_x8s8s32x_fwd_kernel::init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, diff --git a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp index 9b03681d900..aa5aa49135f 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp @@ -270,6 +270,12 @@ status_t jit_uni_x8s8s32x_deconv_fwd_kernel::init_conf( const int sum_ind = p.find(primitive_kind::sum); jcp.with_sum = sum_ind != -1; + const int depthwise_ind = p.find(primitive_kind::depthwise); + jcp.with_depthwise = depthwise_ind != -1; + + const int quantization_ind = p.find(primitive_kind::quantization); + jcp.with_quantization = quantization_ind != -1; + const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); jcp.is_oc_scale = wei_scales.mask_ != 0; @@ -405,7 +411,7 @@ bool jit_uni_x8s8s32x_deconv_fwd_kernel::post_ops_ok(jit_conv_conf_t &jcp, const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { using namespace injector; - return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary}, + return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary, depthwise, quantization}, attr.post_ops_, &dst_d, false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/, true /*sum_requires_same_params*/, @@ -422,7 +428,7 @@ _jit_uni_x8s8s32x_deconv_fwd_kernelparam1_, rhs_sp}; + const quantization_injector::static_params_t qsp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias}; + postops_injector_ = utils::make_unique< injector::jit_uni_postops_injector_t>( - this, jcp_.post_ops, bsp); + this, jcp_.post_ops, bsp, qsp); } } @@ -1056,10 +1064,21 @@ void _jit_uni_x8s8s32x_deconv_fwd_kernel::apply_postops(int ur_w, } } } + + std::map vmm_idx_off; + for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + vmm_idx_off.insert({vmm_out(ur, ocb).getIdx(), ocb * jcp_.oc_block * sizeof(float)}); + } + } + depthwise_injector::dynamic_params_t ddp {vmm_d_weights.getIdx(), vmm_d_bias.getIdx(), reg_d_weights, reg_d_bias, + ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + quantization_injector::dynamic_params_t qdp {ptr[this->param1 + GET_OFF(oc_off)], vmm_idx_off}; + const int nb_oc_block = jcp_.is_depthwise ? jcp_.nb_ch_blocking : jcp_.nb_oc_blocking; postops_injector_->compute_vector_range( - 16 - nb_oc_block * ur_w, 16, rhs_arg_params); + 16 - nb_oc_block * ur_w, 16, rhs_arg_params, ddp, qdp); } template @@ -1143,7 +1162,7 @@ void _jit_uni_x8s8s32x_deconv_fwd_kernel::store_output( if (p_sum_zp && *p_sum_zp != 0) { mov(reg_ptr_sum_zp_, reinterpret_cast(p_sum_zp)); } - if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum) + if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum || jcp_.with_depthwise || jcp_.with_quantization) apply_postops(ur_w, last_oc_block, p_sum_scale, p_sum_zp); if (jcp_.dst_scale) { mov(reg_ptr_dst_scales_, ptr[param1_ + GET_OFF(dst_scale)]); @@ -1573,6 +1592,8 @@ status_t jit_uni_x8s8s32x_deconvolution_fwd_t::execute_forward_1d( p.dst_zero_point = zp_dst; p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); + (*kernel_)(&p); ++start; @@ -1744,6 +1765,8 @@ status_t jit_uni_x8s8s32x_deconvolution_fwd_t::execute_forward_2d( p.dst_zero_point = zp_dst; p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); + (*kernel_)(&p); } if (jcp.loop_order == loop_ngc) @@ -1971,6 +1994,8 @@ status_t jit_uni_x8s8s32x_deconvolution_fwd_t::execute_forward_3d( p.dst_zero_point = zp_dst; p.dst_orig = dst; + p.oc_off = g_oc * sizeof(float); + (*kernel_)(&p); } diff --git a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp index 4385fb0265a..a90eb1deba9 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp @@ -140,6 +140,13 @@ struct _jit_uni_x8s8s32x_deconv_fwd_kernel : public jit_generator { int ur_w, int l_overflow, int r_overflow, bool h_padded); void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow, bool h_padded, bool last_oc_block); + + /* depthwise and quantization post ops */ + const Xbyak::Reg64 reg_d_weights = r15; + const Xbyak::Reg64 reg_d_bias = r13; + Vmm vmm_d_weights = Vmm(0); + Vmm vmm_d_bias = Vmm(1); + void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); void generate() override;