diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp index 1c1761f6d80..2de4ddcf39f 100644 --- a/src/cpu/gemm_convolution_utils.cpp +++ b/src/cpu/gemm_convolution_utils.cpp @@ -996,7 +996,8 @@ void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im, status_t init_conf(conv_gemm_conf_t &jcp, memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads) { + memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads, + bool check_postops) { const memory_desc_wrapper src_d(&src_md); const memory_desc_wrapper weights_d(&weights_md); const memory_desc_wrapper dst_d(&dst_md); @@ -1154,6 +1155,23 @@ status_t init_conf(conv_gemm_conf_t &jcp, CHECK(attr.set_default_formats(&dst_md)); +#if DNNL_X64 + // for x64 we need to check post-ops after tags init + if (check_postops) { + using namespace x64::injector; + static constexpr bool sum_at_pos_0_only = true; + static constexpr bool sum_requires_scale_one = true; + static constexpr bool sum_requires_zp_zero = true; + + VDISPATCH_CONV_IC( + post_ops_ok(post_ops_ok_args_t(x64::avx512_core, + {binary, eltwise, sum}, attr.post_ops_, &dst_d, + sum_at_pos_0_only, sum_requires_scale_one, + sum_requires_zp_zero)), + VERBOSE_UNSUPPORTED_POSTOP); + } +#endif + jcp.post_ops = attr.post_ops_; const int eltwise_ind = jcp.post_ops.find(primitive_kind::eltwise); diff --git a/src/cpu/gemm_convolution_utils.hpp b/src/cpu/gemm_convolution_utils.hpp index f1a795528d0..43e9784bc44 100644 --- a/src/cpu/gemm_convolution_utils.hpp +++ b/src/cpu/gemm_convolution_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2022 Intel Corporation +* Copyright 2016-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -117,7 +117,8 @@ void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im, status_t init_conf(conv_gemm_conf_t &jcp, memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads); + memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads, + bool check_postops = false); void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb); diff --git a/src/cpu/x64/gemm_bf16_convolution.hpp b/src/cpu/x64/gemm_bf16_convolution.hpp index 1cf97bd8012..7fe15a2d36b 100644 --- a/src/cpu/x64/gemm_bf16_convolution.hpp +++ b/src/cpu/x64/gemm_bf16_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,25 +70,10 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t { dst_data_type), VERBOSE_UNSUPPORTED_ATTR); - { - using namespace x64::injector; - static constexpr bool sum_at_pos_0_only = true; - static constexpr bool sum_requires_scale_one = true; - static constexpr bool sum_requires_zp_zero = true; - const auto dst_md = memory_desc_wrapper(dst_md_); - - VDISPATCH_CONV( - post_ops_ok(post_ops_ok_args_t(avx512_core, - {binary, eltwise, sum}, attr()->post_ops_, - &dst_md, sum_at_pos_0_only, - sum_requires_scale_one, sum_requires_zp_zero)), - VERBOSE_UNSUPPORTED_POSTOP); - } - auto scratchpad = scratchpad_registry().registrar(); return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, - dnnl_get_max_threads()); + dnnl_get_max_threads(), true /* check_postops */); } bool is_postprocess_required() const {