gpu: jit: conv: enable stride support
dyoussif committed Dec 11, 2024
commit d0943f2 (1 parent: 72771d1)
Showing 5 changed files with 35 additions and 8 deletions.
src/gpu/intel/jit/conv/config.cpp (13 additions, 0 deletions)
@@ -150,6 +150,13 @@ status_t conv_problem_t::init(
     with_bias = conv_pd->with_bias();
     with_groups = conv_pd->with_groups();
     with_sum = with_sum_post_op();
+    memory_desc_wrapper mdw_src(conv_pd->invariant_src_md());
+    memory_desc_wrapper mdw_wei(conv_pd->invariant_wei_md());
+    memory_desc_wrapper mdw_dst(conv_pd->invariant_dst_md());
+
+    strided = (mdw_src.is_plain() && !mdw_src.is_dense())
+            || (mdw_wei.is_plain() && !mdw_wei.is_dense())
+            || (mdw_dst.is_plain() && !mdw_dst.is_dense());
 
     src_data_type = conv_pd->invariant_src_md()->data_type;
     wei_data_type = conv_pd->invariant_wei_md()->data_type;
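The `strided` flag flips exactly when a tensor is plain (no blocking) yet not dense, i.e. some stride is larger than a fully packed permutation of the dims would produce. A minimal standalone sketch of that condition; the real check lives in `memory_desc_wrapper`, and the helper below and its names are illustrative only:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-in for mdw.is_plain() && !mdw.is_dense(): walking
// dims from smallest to largest stride, a plain layout is strided when
// some stride exceeds the accumulated size of the faster-varying dims.
bool is_plain_but_not_dense(const std::vector<int64_t> &dims,
        const std::vector<int64_t> &strides) {
    std::vector<size_t> order(dims.size());
    for (size_t i = 0; i < order.size(); i++)
        order[i] = i;
    std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return strides[a] < strides[b]; });
    int64_t dense_stride = 1;
    for (size_t idx : order) {
        if (strides[idx] != dense_stride) return true; // gap -> strided
        dense_stride *= dims[idx];
    }
    return false; // a packed permutation of the dims -> dense
}
```

For example, NCHW dims {2, 16, 32, 32} with strides {32768, 2048, 64, 1} are plain but strided: the H stride of 64 leaves a 32-element gap after each 32-element row.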
@@ -511,6 +518,8 @@ struct goi_block_t {
 // Matches the user-provided descriptor against the list of supported plain tags.
 std::string get_plain_user_tag(
         const conv_problem_t &prb, const memory_desc_t &md, bool is_wei) {
+    memory_desc_wrapper mdw(md);
+    if (mdw.is_plain() && !mdw.is_dense()) return "user";
     if (is_wei) {
         std::vector<const char *> plain_non_group_wei_tags
                 = {"abx", "axb", "xba"};
@@ -644,6 +653,10 @@ void init_data_tags(const conv_config_t &cfg, const memory_desc_t &src_md,
     // Use plain tag for output to avoid extra reorders.
     if (src_output) src_tag = user_src_tag;
     if (dst_output) dst_tag = user_dst_tag;
+
+    if (user_src_req == "user") src_tag = user_src_tag = "user";
+    if (user_wei_req == "user") wei_tag = user_wei_tag = "user";
+    if (user_dst_req == "user") dst_tag = user_dst_tag = "user";
 }
 
 status_t init_tensor_layouts(
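For context, here is the kind of descriptor that now flows through as "user" instead of being forced into a canonical tag: a plain oneDNN memory descriptor built with explicit, padded strides (the sizes are arbitrary examples):

```cpp
#include "dnnl.hpp"

// An NCHW f32 source whose rows are padded from W = 32 to a pitch of
// 64 elements, e.g. a crop out of a wider image. is_plain() holds but
// is_dense() does not, so conv_problem_t::strided is set and the
// "user" tag carries the original strides through layout selection.
dnnl::memory::desc make_strided_src() {
    dnnl::memory::dims dims = {2, 16, 32, 32}; // N, C, H, W
    dnnl::memory::dims strides = {16 * 32 * 64, 32 * 64, 64, 1};
    return dnnl::memory::desc(dims, dnnl::memory::data_type::f32, strides);
}
```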
src/gpu/intel/jit/conv/problem.cpp (17 additions, 6 deletions)
@@ -167,22 +167,30 @@ const std::vector<pvar_t> &conv_padding_dims() {
     return _padding_dims;
 }
 
-bool can_reduce_to_1d(const memory_desc_t &out_md, const post_ops_t &post_ops) {
-    int ndims = out_md.ndims;
+bool can_reduce_to_1d(const memory_desc_t &md, const post_ops_t &post_ops) {
+    int ndims = md.ndims;
     int sp_ndims = ndims - 2;
     int non_one_sp_ndims = 0;
-    for (int i = ndims - sp_ndims; i < ndims; i++) {
-        if (out_md.dims[i] != 1) non_one_sp_ndims++;
+    auto &strides = md.format_desc.blocking.strides;
+    dim_t sp_size = strides[ndims - 1];
+    bool sp_dense = true;
+    for (int i = ndims - 1; i >= ndims - sp_ndims; i--) {
+        if (md.dims[i] != 1) non_one_sp_ndims++;
+        if (strides[i] != sp_size) sp_dense = false;
+        sp_size *= md.dims[i];
     }
     if (non_one_sp_ndims == 1) return true;
+    memory_desc_wrapper mdw(md);
+    bool strided = mdw.is_plain() && !sp_dense;
+    if (strided) return false;
     for (int i = 0; i < post_ops.len(); i++) {
         auto &po = post_ops.entry_[i];
         int mask = 0;
         if (po.is_prelu()) {
             mask = po.prelu.mask;
         } else if (po.is_binary()) {
             mask = utils::get_dims_mask(
-                    out_md.dims, po.binary.src1_desc.dims, ndims);
+                    md.dims, po.binary.src1_desc.dims, ndims);
         }
         // If the post-op is applied per D/H/W dimension then it cannot be
         // transformed to 1D.
@@ -196,7 +204,10 @@ bool can_reduce_to_1d(const memory_desc_t &out_md, const post_ops_t &post_ops) {
 void conv_problem_t::normalize_shape() {
     normalize_conv_shape(id, od, kd, sd, dd, pd, ih, oh, kh, sh, dh, ph, iw, ow,
             kw, sw, dw, pw,
-            can_reduce_to_1d(c_md(), conv_pd->attr()->post_ops_), dhw_map);
+            can_reduce_to_1d(c_md(), conv_pd->attr()->post_ops_)
+                    && can_reduce_to_1d(a_md(), post_ops_t())
+                    && can_reduce_to_1d(b_md(), post_ops_t()),
+            dhw_map);
 }
 
 const memory_desc_t &conv_problem_t::a_md() const {
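The reworked loop above does two jobs at once: it counts non-unit spatial dims and verifies that the D/H/W strides form the packed sequence stride(w), stride(w)*W, stride(w)*W*H, which is what makes the spatial dims mergeable. A sketch of that walk with a worked example (names and sizes are illustrative):

```cpp
#include <cstdint>
#include <vector>

// Mirrors the density walk in can_reduce_to_1d: spatial dims may merge
// into one dim only if each spatial stride equals stride(w) times the
// accumulated size of the faster-varying spatial dims.
bool spatial_is_dense(const std::vector<int64_t> &dims,
        const std::vector<int64_t> &strides, int sp_ndims) {
    int ndims = (int)dims.size();
    int64_t sp_size = strides[ndims - 1]; // stride of W
    for (int i = ndims - 1; i >= ndims - sp_ndims; i--) {
        if (strides[i] != sp_size) return false;
        sp_size *= dims[i];
    }
    return true;
}

// NCDHW dims {2, 8, 4, 6, 5} with strides {960, 120, 30, 5, 1} are
// spatially dense (5 = W, 30 = W * H), so D/H/W collapse into a single
// dim of size 4 * 6 * 5 = 120. With W padded to a pitch of 8, strides
// become {1536, 192, 48, 8, 1}; stride(h) = 8 != 5, so the problem
// keeps its three spatial dims, which is why strided plain layouts now
// return false above.
```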
src/gpu/intel/jit/conv/problem.hpp (1 addition, 0 deletions)
@@ -174,6 +174,7 @@ class conv_problem_t {
     bool with_sum = false;
     bool is_dw = false;
     bool ab_swap_transpose = false;
+    bool strided = false;
 
     int ndims = 0;
     dim_t mb = 0; // Batch size.
src/gpu/intel/jit/conv/tiler.cpp (3 additions, 2 deletions)
@@ -1123,8 +1123,9 @@ conv_blocking_scheme_list_t get_blocking_schemes_bwd_w(
     bool k_is_mb = (k_iter_dim == pvars::mb);
     bool k_is_ow = (k_iter_dim == pvars::ow);
     bool small_ic = is_small_ic(cfg.prb());
-    ret.add(k_is_mb, conv_schemes::bwd_w_T_io_I_ion);
-    ret.add(k_is_ow, conv_schemes::bwd_w_T_io_I_iow);
+    bool strided = cfg.prb().strided;
+    ret.add(k_is_mb || strided, conv_schemes::bwd_w_T_io_I_ion);
+    ret.add(k_is_ow || strided, conv_schemes::bwd_w_T_io_I_iow);
     ret.add(k_is_mb && small_ic, conv_schemes::bwd_w_T_io_I_kon);
     ret.add(k_is_mb && small_ic, conv_schemes::bwd_w_T_io_I_ikon);
     ret.add(k_is_ow && small_ic, conv_schemes::bwd_w_T_io_I_ikow);
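A hedged reading of the pattern above (the real list type is `conv_blocking_scheme_list_t`; the stand-in below only shows its shape): each `add` registers a candidate blocking scheme under a guard, so a strided problem simply widens the bwd_w candidate list rather than taking a separate code path:

```cpp
#include <vector>

// Stand-in for the guarded registration used by the tiler: schemes are
// appended only when their guard holds; all survivors are scored later.
struct scheme_list_t {
    std::vector<int> schemes; // placeholder for blocking-scheme handles
    void add(bool guard, int scheme) {
        if (guard) schemes.push_back(scheme);
    }
};
```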
src/gpu/intel/jit/ir/tensor_config.hpp (1 addition, 0 deletions)
@@ -117,6 +117,7 @@ inline layout_t make_layout(const memory_desc_t &md) {
 }
 
 inline layout_t make_layout(const memory_desc_t &md, const std::string &tag) {
+    if (tag == "user") return layout_t(md);
     return layout_t(md, tag, /*do_normalize=*/false);
 }
 
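A short usage sketch for the sentinel path; `src_md` stands for any plain `memory_desc_t` in scope, and `show_paths` is a hypothetical helper, not part of the header:

```cpp
// Named tag: the tag string ("axb" here) defines the layout and md only
// supplies sizes. Sentinel tag: dims and strides are taken verbatim
// from md, so arbitrary user strides survive with no reorder inserted.
inline void show_paths(const memory_desc_t &src_md) {
    layout_t named = make_layout(src_md, "axb");
    layout_t user = make_layout(src_md, "user");
    (void)named;
    (void)user;
}
```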
