[FORK][FEATURE] Introduced Depthwise and Quantization post ops
Primitives that support the new post ops (usage sketch below):
- JIT convolutions (FP32, BF16, INT8)
- JIT deconvolution (INT8)
- JIT and reference pooling (INT8)

AMX primitives: explicitly pass dst_type into has_default_values checks

Extended int8 AMX convolutions to support depthwise/quantization post ops

oneDNN 3.2 migration, squashed commits:
   - Correct assert for reg id in dw conv kernel

oneDNN 3.5 squash list:
[FIX] Fix AVX2 int8 binary post-ops register conflict
dmitry-gorokhov authored and luweizhou2016 committed Nov 26, 2024
1 parent 115e9fa commit bf3eabd
Showing 94 changed files with 2,078 additions and 237 deletions.
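
Before the per-file diff, a brief usage sketch of the new C++ API. Only append_depthwise, set_post_ops, and the algorithm enum values come from this commit; the channel count, buffers, and the per-channel semantics (dst[c] = dst[c] * weights[c] + biases[c]) are assumptions for illustration. append_quantization is sketched further below, next to its implementation.

#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    const int OC = 16; // hypothetical output-channel count

    // The post-op stores raw pointers, so these buffers must outlive
    // any primitive built with the attribute.
    std::vector<float> dw_weights(OC, 1.0f), dw_biases(OC, 0.0f);

    dnnl::post_ops ops;
    // Presumed semantics: dst[c] = dst[c] * dw_weights[c] + dw_biases[c]
    ops.append_depthwise(dnnl::algorithm::depthwise_scale_shift,
            dw_weights.data(), dw_biases.data());

    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);
    // attr would then be passed to, e.g., a convolution primitive_desc.
    return 0;
}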
10 changes: 10 additions & 0 deletions include/oneapi/dnnl/dnnl.h
@@ -795,6 +795,16 @@ dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
const_dnnl_post_ops_t post_ops, int index, int *mask);

dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise(
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
const float* weights_data, const float* biases_data);

dnnl_status_t DNNL_API dnnl_post_ops_append_quantization(
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
const void* crop_low, const void* crop_high,
const void* input_scale, const void* input_shift,
const void* output_scale, const void* output_shift);

/// @} dnnl_api_attributes

/// @} dnnl_api_primitives
22 changes: 22 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
@@ -496,6 +496,12 @@ enum class algorithm {
softmax_accurate = dnnl_softmax_accurate,
/// LogSoftmax, numerically stable
softmax_log = dnnl_softmax_log,

depthwise_scale_shift = dnnl_depthwise_scale_shift,
depthwise_prelu = dnnl_depthwise_prelu,

quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize,
quantization_quantize = dnnl_quantization_quantize,
};

/// Converts algorithm kind enum value from C++ API to C API type.
@@ -3924,6 +3930,22 @@ struct post_ops : public handle<dnnl_post_ops_t> {
error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
"could not get parameters of a binary post-op");
}

void append_depthwise(algorithm alg, const float* weights_data,
const float* biases_data) {
error::wrap_c_api(dnnl_post_ops_append_depthwise(get(),
convert_to_c(alg), weights_data, biases_data),
"could not append depthwise");
}

void append_quantization(algorithm alg,
const void* crop_low, const void* crop_high,
const void* input_scale, const void* input_shift,
const void* output_scale, const void* output_shift) {
error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), crop_low, crop_high,
input_scale, input_shift, output_scale, output_shift),
"could not append quantization");
}
};

/// @cond DO_NOT_DOCUMENT_THIS
10 changes: 10 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
@@ -1992,6 +1992,10 @@ typedef enum {
dnnl_deconvolution,
/// An element-wise primitive.
dnnl_eltwise,
/// A depthwise primitive.
dnnl_depthwise,
/// A quantization primitive.
dnnl_quantization,
/// An LRN primitive.
dnnl_lrn,
/// A batch normalization primitive.
@@ -2176,6 +2180,12 @@ typedef enum {
dnnl_softmax_accurate = 0x30000,
/// Logsoftmax
dnnl_softmax_log,

dnnl_depthwise_scale_shift = 0x3fff0,
dnnl_depthwise_prelu = 0x3fff1,

dnnl_quantization_quantize_dequantize = 0x4fff0,
dnnl_quantization_quantize = 0x4fff1,
} dnnl_alg_kind_t;

/// Flags for normalization primitives.
6 changes: 6 additions & 0 deletions src/common/c_types_map.hpp
@@ -141,6 +141,10 @@ const alg_kind_t reduction_norm_lp_power_p_sum
= dnnl_reduction_norm_lp_power_p_sum;
const alg_kind_t softmax_accurate = dnnl_softmax_accurate;
const alg_kind_t softmax_log = dnnl_softmax_log;
const alg_kind_t depthwise_scale_shift = dnnl_depthwise_scale_shift;
const alg_kind_t depthwise_prelu = dnnl_depthwise_prelu;
const alg_kind_t quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize;
const alg_kind_t quantization_quantize = dnnl_quantization_quantize;
} // namespace alg_kind

using data_type_t = dnnl_data_type_t;
@@ -1949,6 +1953,8 @@ const primitive_kind_t reduction = dnnl_reduction;
const primitive_kind_t softmax = dnnl_softmax;
const primitive_kind_t layer_normalization = dnnl_layer_normalization;
const primitive_kind_t group_normalization = dnnl_group_normalization;
const primitive_kind_t depthwise = dnnl_depthwise;
const primitive_kind_t quantization = dnnl_quantization;

// Internal only primitive kinds.
const primitive_kind_t internal_only_start = (primitive_kind_t)(1 << 12);
6 changes: 6 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
@@ -1753,6 +1753,8 @@ const char *dnnl_prim_kind2str(dnnl_primitive_kind_t v) {
if (v == dnnl_softmax) return "softmax";
if (v == dnnl_layer_normalization) return "layer_normalization";
if (v == dnnl_group_normalization) return "group_normalization";
if (v == dnnl_depthwise) return "depthwise";
if (v == dnnl_quantization) return "quantization";
if (v == dnnl_primitive_kind_max) return "primitive_kind_max";
if (v == dnnl::impl::primitive_kind::sdpa) return "sdpa";
assert(!"unknown prim_kind");
@@ -1830,6 +1832,10 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
if (v == dnnl_reduction_norm_lp_power_p_sum) return "reduction_norm_lp_power_p_sum";
if (v == dnnl_softmax_accurate) return "softmax_accurate";
if (v == dnnl_softmax_log) return "softmax_log";
if (v == dnnl_depthwise_scale_shift) return "depthwise_scale_shift";
if (v == dnnl_depthwise_prelu) return "depthwise_prelu";
if (v == dnnl_quantization_quantize_dequantize) return "quantization_quantize_dequantize";
if (v == dnnl_quantization_quantize) return "quantization_quantize";
assert(!"unknown alg_kind");
return "unknown alg_kind";
}
2 changes: 2 additions & 0 deletions src/common/ittnotify.cpp
@@ -80,6 +80,8 @@ void primitive_task_start(primitive_kind_t kind) {
CASE(layer_normalization),
CASE(group_normalization),
CASE(sdpa),
CASE(depthwise),
CASE(quantization),
};
#undef CASE
int kind_idx = (int)kind;
37 changes: 37 additions & 0 deletions src/common/math_utils.hpp
@@ -567,6 +567,43 @@ inline float stochastic_round_fwd(
return r;
}

inline float get_bias(const char *bias, size_t offset, data_type_t data_type) {
if (!bias) return 0.0f;

#define CASE(dt) \
case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]

switch (data_type) {
CASE(data_type::s8);
CASE(data_type::u8);
CASE(data_type::bf16);
CASE(data_type::s32);
CASE(data_type::f32);
default: assert(!"unimplemented");
}
return 0; // never happens (should probably be a NaN)
#undef CASE
}

inline float get_sum(char *sum, size_t offset, data_type_t data_type) {
if (!sum) return 0.0f;

#define CASE(dt) \
case dt: return (float)((const prec_traits<dt>::type *)sum)[offset]

switch (data_type) {
CASE(data_type::s8);
CASE(data_type::u8);
CASE(data_type::s32);
CASE(data_type::f32);
default: assert(!"unimplemented");
}
return 0; // never happens (should probably be a NaN)
#undef CASE
}

} // namespace math
} // namespace impl
} // namespace dnnl
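
get_bias() and get_sum() share one idiom: switch on a runtime data_type_t, reinterpret the type-erased buffer through prec_traits, and widen a single element to float. A self-contained sketch of that dispatch pattern (simplified enum and plain casts; not oneDNN code):

#include <cassert>
#include <cstddef>
#include <cstdint>

enum class data_type { s8, u8, s32, f32 };

// Read element `offset` of a type-erased buffer as float.
inline float load_as_float(const void *buf, std::size_t offset, data_type dt) {
    if (!buf) return 0.0f;
    switch (dt) {
        case data_type::s8: return (float)((const std::int8_t *)buf)[offset];
        case data_type::u8: return (float)((const std::uint8_t *)buf)[offset];
        case data_type::s32: return (float)((const std::int32_t *)buf)[offset];
        case data_type::f32: return ((const float *)buf)[offset];
    }
    assert(!"unimplemented");
    return 0.0f;
}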
4 changes: 4 additions & 0 deletions src/common/nstl.hpp
@@ -339,6 +339,10 @@ class vector : public c_compatible {
}
void clear() { _impl.clear(); }
void push_back(const T &t) { _impl.push_back(t); }
template<typename... Args>
void emplace_back(Args&&... args) {
_impl.emplace_back(std::forward<Args>(args)...);
}
void resize(size_type count) { _impl.resize(count); }
void reserve(size_type count) { _impl.reserve(count); }
};
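
This emplace_back overload is what lets primitive_attr.cpp (further below) default-construct a post-op entry in place via entry_.emplace_back(). A minimal illustration of the forwarding it wraps, with a hypothetical element type and plain std::vector:

#include <vector>

struct entry_demo {
    int kind = 0;
    const float *data = nullptr;
    entry_demo() = default;
    entry_demo(int k, const float *d) : kind(k), data(d) {}
};

int main() {
    std::vector<entry_demo> v;
    v.emplace_back();           // default-construct in place (the entry_.emplace_back() case)
    v.emplace_back(1, nullptr); // arguments forwarded straight to the constructor, no temporary
    return 0;
}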
97 changes: 97 additions & 0 deletions src/common/primitive_attr.cpp
@@ -73,6 +73,29 @@ status_t scales_t::set(dim_t count, int mask, const float *scales) {
return status::success;
}


template <typename T>
status_t shifts_t<T>::set(int count, int mask, const T *shifts) {
cleanup();

count_ = count;
mask_ = mask;

if (count_ == 1) {
shifts_ = shifts_buf_;
utils::array_set(shifts_, shifts[0], shifts_buf_size);
} else {
shifts_ = (T *)impl::malloc(count_ * sizeof(*shifts_), 64);
if (shifts_ == nullptr)
return status::out_of_memory;

for (int c = 0; c < count_; ++c)
shifts_[c] = shifts[c];
}

return status::success;
}

status_t zero_points_t::get(int arg, int *mask, data_type_t *dt) const {
if (mask) *mask = get_mask(arg);
if (dt) *dt = get_data_type(arg);
@@ -182,6 +205,14 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
#undef CHECK_ARG
}

bool primitive_attr_t::has_asymmetric_quantization() const {
return true
&& output_scales_.has_default_values()
&& rnn_data_qparams_.has_default_values()
&& rnn_weights_qparams_.has_default_values()
&& (!input_zero_points_.has_default_values() || !weights_zero_points_.has_default_values());
}

bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
using smask_t = skip_mask_t;
bool ok = true;
@@ -313,6 +344,47 @@ status_t post_ops_t::append_prelu(int mask) {
return success;
}

status_t post_ops_t::append_depthwise(alg_kind_t alg, const float* weights_data, const float* biases_data) {
using namespace dnnl::impl::alg_kind;
if (len() == post_ops_limit) return out_of_memory;
bool known_alg = one_of(alg, depthwise_scale_shift, depthwise_prelu);
if (!known_alg)
return invalid_arguments;

entry_.emplace_back();
auto &e = entry_.back();
e.kind = primitive_kind::depthwise;
e.depthwise.alg = alg;
e.depthwise.weights_data = weights_data;
e.depthwise.biases_data = biases_data;

return success;
}

status_t post_ops_t::append_quantization(alg_kind_t alg,
const void* crop_low, const void* crop_high,
const void* input_scale, const void* input_shift,
const void* output_scale, const void* output_shift) {
using namespace dnnl::impl::alg_kind;
if (len() == post_ops_limit) return out_of_memory;
bool known_alg = one_of(alg, quantization_quantize_dequantize, quantization_quantize);
if (!known_alg)
return invalid_arguments;

entry_.emplace_back();
auto &e = entry_.back();
e.kind = primitive_kind::quantization;
e.quantization.alg = alg;
e.quantization.crop_low_data = reinterpret_cast<const shifts_t<float>*>(crop_low);
e.quantization.crop_high_data = reinterpret_cast<const shifts_t<float>*>(crop_high);
e.quantization.input_scale_data = reinterpret_cast<const scales_t*>(input_scale);
e.quantization.input_shift_data = reinterpret_cast<const shifts_t<float>*>(input_shift);
e.quantization.output_scale_data = reinterpret_cast<const scales_t*>(output_scale);
e.quantization.output_shift_data = reinterpret_cast<const shifts_t<float>*>(output_shift);

return success;
}
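
The implementation above pins down the contract behind the C API's const void* parameters: each one is really a pointer to one of the library's internal scales_t / shifts_t<float> containers. A hedged sketch of how a caller with access to the src/common headers might build them — the set() signatures come from this diff, everything else is illustrative:

// Assumed: oneDNN internal headers (src/common) are reachable.
float cl = 0.f, ch = 255.f, is = 1.f, ish = 0.f, os = 1.f, osh = 0.f;

dnnl::impl::shifts_t<float> crop_low, crop_high, input_shift, output_shift;
dnnl::impl::scales_t input_scale, output_scale;
crop_low.set(1, 0, &cl);  // count == 1 broadcasts the single value
crop_high.set(1, 0, &ch);
input_scale.set(1, 0, &is);
input_shift.set(1, 0, &ish);
output_scale.set(1, 0, &os);
output_shift.set(1, 0, &osh);

post_ops.append_quantization(
        dnnl::impl::alg_kind::quantization_quantize_dequantize,
        &crop_low, &crop_high, &input_scale, &input_shift,
        &output_scale, &output_shift);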

bool post_ops_t::defined() const {
for (int idx = 0; idx < len(); ++idx) {
auto kind = entry_[idx].kind;
@@ -327,6 +399,10 @@ bool post_ops_t::defined() const {
primitive_kind::prelu,
primitive_kind::convolution)) {
// binary is always defined
} else if (kind == primitive_kind::depthwise) {
// depthwise is always defined
} else if (kind == primitive_kind::quantization) {
// quantization is always defined
} else {
assert(!"unreachable");
}
@@ -787,6 +863,23 @@ status_t dnnl_post_ops_get_params_prelu(
return success;
}

status_t dnnl_post_ops_append_depthwise(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
const float* weights_data, const float* biases_data) {
if (post_ops == nullptr) return invalid_arguments;

return post_ops->append_depthwise(alg, weights_data, biases_data);
}

status_t dnnl_post_ops_append_quantization(post_ops_t *post_ops, alg_kind_t kind,
const void* crop_low, const void* crop_high,
const void* input_scale, const void* input_shift,
const void* output_scale, const void* output_shift) {
if (post_ops == nullptr)
return invalid_arguments;

return post_ops->append_quantization(kind, crop_low, crop_high, input_scale, input_shift, output_scale, output_shift);
}

status_t dnnl_primitive_attr_set_rnn_data_qparams(
primitive_attr_t *attr, const float scale, const float shift) {
if (attr == nullptr) return invalid_arguments;
@@ -854,3 +947,7 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams(

return attr->rnn_tparams_.set(mode, ngates, scales, cscale);
}

template struct dnnl::impl::shifts_t<uint8_t>;
template struct dnnl::impl::shifts_t<int32_t>;
template struct dnnl::impl::shifts_t<float>;
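
The three explicit instantiations above are needed because shifts_t<T>::set is defined in this .cpp rather than in the header; without them, translation units that use those specializations would fail to link. A minimal single-file sketch of the pattern (hypothetical names):

#include <cstdint>

// --- header part: declaration only, definition intentionally absent here.
template <typename T>
struct shifts_demo {
    T value {};
    void set(const T *v); // defined out of line, in one .cpp
};

// --- .cpp part: out-of-line definition plus explicit instantiations.
template <typename T>
void shifts_demo<T>::set(const T *v) { value = *v; }

// Emit object code for exactly these specializations so translation
// units that only see the header can still link against them.
template struct shifts_demo<float>;
template struct shifts_demo<std::int32_t>;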