Skip to content

Commit

Permalink
gpu: jit: conv: fix performance with OHWI weights
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh committed Jan 23, 2023
1 parent c616453 commit 2d0b31e
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/gpu/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,18 @@ bool can_use_2d_send(const conv_config_t &cfg, const layout_t &l, bool is_a) {
// 2D messages do not support vnni format with 4 byte elements
if (type_t(prb.b_data_type).size() >= 4) return false;

// Plain-weights check: accepts an empty layout, otherwise requires a strict
// match against one of the tags "xba", "xab", "axb" (tested in that order).
auto is_plain_wei_ok = [&]() {
    if (l.is_empty()) return true;
    return matches_tag_strict(l, "xba") || matches_tag_strict(l, "xab")
            || matches_tag_strict(l, "axb");
};

// Decides whether layout `l` is acceptable as a plain layout, using the
// is_a / is_b flags and the propagation flags read from `prb`.
// (is_b is defined outside this excerpt.)
// NOTE(review): this listing comes from a diff view and appears to show
// removed and added lines together - presumably the three `is_b` checks
// below are the deleted lines and the `is_wei` lines replace them; confirm
// against the actual file before relying on this ordering.
auto is_plain_ok = [&]() {
// A operand, or any operand for backward-by-weights: strict "axb" only.
if (is_a || prb.is_bwd_w) return matches_tag_strict(l, "axb");
// B operand with an empty layout is accepted as is.
if (is_b && l.is_empty()) return true;
// B operand, forward: strict "xba" only.
if (is_b && prb.is_fwd) return matches_tag_strict(l, "xba");
// B operand, backward-by-data: strict "xab" only.
if (is_b && prb.is_bwd_d) return matches_tag_strict(l, "xab");
// As written, both terms of is_wei have already returned above, so is_wei
// is always false here and is_plain_wei_ok() is never called on this path.
bool is_wei = (is_b && prb.is_fwd) || (is_b && prb.is_bwd_d);
if (is_wei) return is_plain_wei_ok();
// No rule matched: not a supported plain layout.
return false;
};

Expand Down

0 comments on commit 2d0b31e

Please sign in to comment.