Skip to content

Commit

Permalink
cpu: resampling: fix zero padding issues with postops
Browse files: browse the repository at this point in the history
  • Loading branch information
xuxinzen authored and tprimak committed Jan 13, 2023
1 parent 5bd5d52 commit aa52a51
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
56 changes: 38 additions & 18 deletions src/cpu/simple_resampling.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
* Copyright 2019-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,8 +47,9 @@ struct simple_resampling_kernel_t : public simple_resampling_base_t {
status_t execute(const exec_ctx_t &ctx) const override;

private:
using interpolate_fn_t = std::function<void(const src_data_t *,
dst_data_t *, ref_post_ops_t::args_t &, dim_t, dim_t, dim_t)>;
using interpolate_fn_t
= std::function<void(const src_data_t *, dst_data_t *,
ref_post_ops_t::args_t &, dim_t, dim_t, dim_t, const bool)>;

void fill_coeffs();
void fill_weights();
Expand Down Expand Up @@ -83,6 +84,7 @@ simple_resampling_kernel_t<src_type, dst_type>::simple_resampling_kernel_t(
stride_d_ = pd_->IH() * pd_->IW() * inner_stride_;
stride_h_ = pd_->IW() * inner_stride_;
stride_w_ = inner_stride_;
tail_size_ = pd_->C() % inner_stride_;
} else {
const memory_desc_wrapper diff_src_d(pd_->diff_src_md());
inner_stride_ = diff_src_d.blocking_desc().strides[pd_->ndims() - 1];
Expand All @@ -91,6 +93,7 @@ simple_resampling_kernel_t<src_type, dst_type>::simple_resampling_kernel_t(
stride_d_ = pd_->OH() * pd_->OW() * inner_stride_;
stride_h_ = pd_->OW() * inner_stride_;
stride_w_ = inner_stride_;
tail_size_ = pd_->C() % inner_stride_;
}
}

Expand Down Expand Up @@ -122,6 +125,7 @@ status_t simple_resampling_kernel_t<src_type, dst_type>::execute(
const int ID = pd_->ID();
const int IH = pd_->IH();
const int IW = pd_->IW();
const int NB_CH = utils::div_up(pd_->C(), inner_stride_);

if (pd_->is_fwd()) {
const auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
Expand All @@ -132,6 +136,9 @@ status_t simple_resampling_kernel_t<src_type, dst_type>::execute(
postops_args.ctx = &ctx;
postops_args.dst_md = pd_->dst_md();

const bool preserve_zero_padding
= (nsp0 + 1) % NB_CH == 0 && tail_size_ != 0;

for (dim_t ow = 0; ow < OW; ow++) {
const dim_t src_off = nsp0 * ID * IH * IW * inner_stride_;
const dim_t dst_off
Expand All @@ -140,8 +147,8 @@ status_t simple_resampling_kernel_t<src_type, dst_type>::execute(

postops_args.l_offset = dst_off;

interpolate_fn_(
src + src_off, dst + dst_off, postops_args, od, oh, ow);
interpolate_fn_(src + src_off, dst + dst_off, postops_args, od,
oh, ow, preserve_zero_padding);
}
});
} else {
Expand All @@ -157,7 +164,8 @@ status_t simple_resampling_kernel_t<src_type, dst_type>::execute(
= (nsp * ID * IH * IW + id * IH * IW + ih * IW + iw)
* inner_stride_;
interpolate_fn_(diff_dst + diff_dst_off,
diff_src + diff_src_off, empty_args, id, ih, iw);
diff_src + diff_src_off, empty_args, id, ih, iw,
false);
});
}

Expand Down Expand Up @@ -223,7 +231,7 @@ simple_resampling_kernel_t<src_type, dst_type>::create_nearest() const {
if (pd_->is_fwd()) {
return [&](const src_data_t *src, dst_data_t *dst,
ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
dim_t ow) {
dim_t ow, const bool preserve_zero_padding) {
const dim_t id = nearest_idx(od, pd_->OD(), pd_->ID());
const dim_t ih = nearest_idx(oh, pd_->OH(), pd_->IH());
const dim_t iw = nearest_idx(ow, pd_->OW(), pd_->IW());
Expand All @@ -235,7 +243,9 @@ simple_resampling_kernel_t<src_type, dst_type>::create_nearest() const {
innermost_el++) {
float res = static_cast<float>(src[offset + innermost_el]);

if (are_postops_set_) {
if (are_postops_set_
&& IMPLICATION(preserve_zero_padding,
innermost_el < tail_size_)) {
po_args.dst_val = dst[innermost_el];
ref_post_ops_.execute(res, po_args);
po_args.l_offset++;
Expand All @@ -247,7 +257,7 @@ simple_resampling_kernel_t<src_type, dst_type>::create_nearest() const {
} else {
return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
dim_t iw) {
dim_t iw, const bool preserve_zero_padding) {
auto ow_idx = [&](const float in_idx) -> dim_t {
return ceil_idx((in_idx * pd_->OW() / pd_->IW()) - 0.5f);
};
Expand All @@ -257,6 +267,7 @@ simple_resampling_kernel_t<src_type, dst_type>::create_nearest() const {
auto od_idx = [&](const float in_idx) -> dim_t {
return ceil_idx((in_idx * pd_->OD() / pd_->ID()) - 0.5f);
};
MAYBE_UNUSED(preserve_zero_padding);

const dim_t ow_start = ow_idx(iw) * stride_w_;
const dim_t oh_start = oh_idx(ih) * stride_h_;
Expand Down Expand Up @@ -288,7 +299,7 @@ simple_resampling_kernel_t<src_type, dst_type>::create_linear() const {
if (pd_->is_fwd()) {
return [&](const src_data_t *src, dst_data_t *dst,
ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
dim_t ow) {
dim_t ow, const bool preserve_zero_padding) {
const linear_coeffs_t &iw
= linear_coeffs_[pd_->OD() + pd_->OH() + ow];

Expand All @@ -301,7 +312,9 @@ simple_resampling_kernel_t<src_type, dst_type>::create_linear() const {
src[iw.idx[k] * stride_w_ + innermost_el])
* iw.wei[k];

if (are_postops_set_) {
if (are_postops_set_
&& IMPLICATION(preserve_zero_padding,
innermost_el < tail_size_)) {
po_args.dst_val = dst[innermost_el];
ref_post_ops_.execute(res, po_args);
po_args.l_offset++;
Expand All @@ -313,9 +326,10 @@ simple_resampling_kernel_t<src_type, dst_type>::create_linear() const {
} else {
return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
dim_t iw) {
dim_t iw, const bool preserve_zero_padding) {
const bwd_linear_coeffs_t &w
= bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw];
MAYBE_UNUSED(preserve_zero_padding);

PRAGMA_OMP_SIMD()
for (dim_t innermost_el = 0; innermost_el < inner_stride_;
Expand All @@ -342,7 +356,7 @@ simple_resampling_kernel_t<src_type, dst_type>::create_bilinear() const {
if (pd_->is_fwd()) {
return [&](const src_data_t *src, dst_data_t *dst,
ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
dim_t ow) {
dim_t ow, const bool preserve_zero_padding) {
const linear_coeffs_t &ih = linear_coeffs_[pd_->OD() + oh];
const linear_coeffs_t &iw
= linear_coeffs_[pd_->OD() + pd_->OH() + ow];
Expand All @@ -357,7 +371,9 @@ simple_resampling_kernel_t<src_type, dst_type>::create_bilinear() const {
+ iw.idx[k] * stride_w_ + innermost_el])
* ih.wei[j] * iw.wei[k];

if (are_postops_set_) {
if (are_postops_set_
&& IMPLICATION(preserve_zero_padding,
innermost_el < tail_size_)) {
po_args.dst_val = dst[innermost_el];
ref_post_ops_.execute(res, po_args);
po_args.l_offset++;
Expand All @@ -369,10 +385,11 @@ simple_resampling_kernel_t<src_type, dst_type>::create_bilinear() const {
} else {
return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
dim_t iw) {
dim_t iw, const bool preserve_zero_padding) {
const bwd_linear_coeffs_t &h = bwd_linear_coeffs_[pd_->ID() + ih];
const bwd_linear_coeffs_t &w
= bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw];
MAYBE_UNUSED(preserve_zero_padding);

PRAGMA_OMP_SIMD()
for (dim_t innermost_el = 0; innermost_el < inner_stride_;
Expand Down Expand Up @@ -402,7 +419,7 @@ simple_resampling_kernel_t<src_type, dst_type>::create_trilinear() const {
if (pd_->is_fwd()) {
return [&](const src_data_t *src, dst_data_t *dst,
ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
dim_t ow) {
dim_t ow, const bool preserve_zero_padding) {
const linear_coeffs_t &id = linear_coeffs_[od];
const linear_coeffs_t &ih = linear_coeffs_[pd_->OD() + oh];
const linear_coeffs_t &iw
Expand All @@ -420,7 +437,9 @@ simple_resampling_kernel_t<src_type, dst_type>::create_trilinear() const {
+ iw.idx[k] * stride_w_ + innermost_el])
* id.wei[i] * ih.wei[j] * iw.wei[k];

if (are_postops_set_) {
if (are_postops_set_
&& IMPLICATION(preserve_zero_padding,
innermost_el < tail_size_)) {
po_args.dst_val = dst[innermost_el];
ref_post_ops_.execute(res, po_args);
po_args.l_offset++;
Expand All @@ -432,11 +451,12 @@ simple_resampling_kernel_t<src_type, dst_type>::create_trilinear() const {
} else {
return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
dim_t iw) {
dim_t iw, const bool preserve_zero_padding) {
const bwd_linear_coeffs_t &d = bwd_linear_coeffs_[id];
const bwd_linear_coeffs_t &h = bwd_linear_coeffs_[pd_->ID() + ih];
const bwd_linear_coeffs_t &w
= bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw];
MAYBE_UNUSED(preserve_zero_padding);

PRAGMA_OMP_SIMD()
for (dim_t innermost_el = 0; innermost_el < inner_stride_;
Expand Down
3 changes: 2 additions & 1 deletion src/cpu/simple_resampling.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
* Copyright 2019-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,7 @@ struct simple_resampling_base_t {
dim_t stride_h_ = 0;
dim_t stride_w_ = 0;
dim_t inner_stride_ = 0;
dim_t tail_size_ = 0;
};

struct simple_resampling_fwd_t : public primitive_t {
Expand Down

0 comments on commit aa52a51

Please sign in to comment.