Skip to content

Commit

Permalink
[FORK][FEATURE] Updated sse41 jit convolutions to support padded channels
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov authored and luweizhou2016 committed Nov 26, 2024
1 parent 63e956d commit 115e9fa
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 1 deletion.
13 changes: 13 additions & 0 deletions src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
jcp.mb = src_d.dims()[0];

jcp.oc = dst_d.dims()[1] / jcp.ngroups;
jcp.oc_without_padding = jcp.oc;
jcp.ic = src_d.dims()[1] / jcp.ngroups;

jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
Expand Down Expand Up @@ -645,6 +646,9 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,

const int simd_w = 4;

jcp.oc = rnd_up(jcp.oc, simd_w*2);
jcp.ic = rnd_up(jcp.ic, simd_w*2);

jcp.ic_block = jcp.oc_block = simd_w * 2;

args_ok = true && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
Expand Down Expand Up @@ -810,6 +814,15 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
return status::success;
}

void jit_sse41_1x1_conv_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;

    // When init_conf() rounded jcp.oc up to the SIMD block size, the
    // user-supplied bias is shorter than what the kernel reads; reserve
    // scratchpad space for a zero-padded copy (built in execute_forward()).
    // Note: the templated book<float>() takes an ELEMENT count and scales
    // by sizeof(float) internally — passing `sizeof(float) * jcp.oc` here
    // would over-allocate the buffer 4x.
    if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
        scratchpad.book<float>(key_conv_padded_bias, jcp.oc);
}

} // namespace x64
} // namespace cpu
} // namespace impl
Expand Down
4 changes: 4 additions & 0 deletions src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP
#define CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP

#include "common/memory_tracking.hpp"
#include "common/c_types_map.hpp"
#include "common/memory.hpp"

Expand All @@ -39,6 +40,9 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator {
const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
int nthreads);

static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const jit_1x1_conv_conf_t &jcp);

DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_1x1_conv_kernel_f32)

jit_1x1_conv_conf_t jcp;
Expand Down
10 changes: 10 additions & 0 deletions src/cpu/x64/jit_sse41_1x1_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::memory_tracking::names;

void jit_sse41_1x1_convolution_fwd_t::execute_forward(
const exec_ctx_t &ctx) const {
Expand All @@ -52,6 +53,15 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward(
: std::vector<const void *> {};

auto scratchpad = ctx.get_scratchpad_grantor();

if (pd()->wants_padded_bias()) {
auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias);
utils::array_copy(padded_bias, bias, kernel_->jcp.oc_without_padding);
utils::array_set(padded_bias + kernel_->jcp.oc_without_padding, 0.f,
kernel_->jcp.oc - kernel_->jcp.oc_without_padding);
bias = padded_bias;
}

parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
Expand Down
3 changes: 3 additions & 0 deletions src/cpu/x64/jit_sse41_1x1_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t {
dnnl_get_max_threads()));
if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

auto scratchpad = scratchpad_registry().registrar();
jit_sse41_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);

return status::success;
}

Expand Down
19 changes: 19 additions & 0 deletions src/cpu/x64/jit_sse41_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace x64 {
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::memory_tracking::names;

using namespace Xbyak;

Expand Down Expand Up @@ -398,6 +399,7 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
jcp.mb = src_d.dims()[0];

jcp.oc = dst_d.dims()[1] / jcp.ngroups;
jcp.oc_without_padding = jcp.oc;
jcp.ic = src_d.dims()[1] / jcp.ngroups;

jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
Expand Down Expand Up @@ -491,7 +493,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
VDISPATCH_CONV_IC(channel_pad_ok, VERBOSE_UNSUPPORTED_PAD_FEATURE,
"i/o and padded channel size mismatch");

bool ok_to_pad_channels = true && jcp.ngroups == 1;

const int simd_w = 8; // 2 SSE vectors processing at once
if (ok_to_pad_channels) {
jcp.oc = rnd_up(jcp.oc, simd_w);
if (mimo) {
jcp.ic = rnd_up(jcp.ic, simd_w);
}
}

jcp.ur_h = 1; /* no code-unrolling by h so far */
jcp.ur_w = 3;
Expand Down Expand Up @@ -549,6 +559,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
return status::success;
}

void jit_sse41_conv_fwd_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;

    // If init_conf() padded jcp.oc up to the SIMD width, the caller's bias
    // array does not cover the padded tail; book scratchpad space for a
    // zero-padded bias copy that execute_forward() fills in.
    // Note: book<float>() expects an element count (it multiplies by
    // sizeof(float) itself), so `sizeof(float) * jcp.oc` would over-book
    // the region by a factor of 4.
    if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
        scratchpad.book<float>(key_conv_padded_bias, jcp.oc);
}

} // namespace x64
} // namespace cpu
} // namespace impl
Expand Down
4 changes: 4 additions & 0 deletions src/cpu/x64/jit_sse41_conv_kernel_f32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP
#define CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP

#include "common/memory_tracking.hpp"
#include "common/c_types_map.hpp"
#include "common/memory.hpp"

Expand All @@ -39,6 +40,9 @@ struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator {
const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
int nthreads);

static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const jit_conv_conf_t &jcp);

DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_conv_fwd_kernel_f32)
jit_conv_conf_t jcp;
const primitive_attr_t &attr_;
Expand Down
10 changes: 10 additions & 0 deletions src/cpu/x64/jit_sse41_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::memory_tracking::names;

#define src_blk_off(f, n, c, h, w) \
(pd()->ndims() == 3) ? (f).blk_off(n, c, w) : (f).blk_off(n, c, h, w)
Expand Down Expand Up @@ -60,6 +61,15 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
const bool is_dst_layout_nxc
= one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc);

auto scratchpad = ctx.get_scratchpad_grantor();
if (pd()->wants_padded_bias()) {
auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias);
utils::array_copy(padded_bias, bias, kernel_->jcp.oc_without_padding);
utils::array_set(padded_bias + kernel_->jcp.oc_without_padding, 0.f,
kernel_->jcp.oc - kernel_->jcp.oc_without_padding);
bias = padded_bias;
}

parallel(jcp.nthr, [&](const int ithr, const int nthr) {
assert(nthr == jcp.nthr);

Expand Down
3 changes: 3 additions & 0 deletions src/cpu/x64/jit_sse41_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ struct jit_sse41_convolution_fwd_t : public primitive_t {
*src_md(), *weights_md(), *dst_md(), *attr(),
dnnl_get_max_threads()));

auto scratchpad = scratchpad_registry().registrar();
jit_sse41_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_);

return status::success;
}

Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(

const bool ok_to_pad_channels = true && !is_data_layout_nxc
&& jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups
&& one_of(isa, avx512_core, avx2);
&& one_of(isa, avx512_core, avx2, sse41);
if (ok_to_pad_channels) {
jcp.oc = rnd_up(jcp.oc, simd_w);
jcp.ic = rnd_up(jcp.oc, simd_w);
Expand Down

0 comments on commit 115e9fa

Please sign in to comment.