Skip to content

Commit

Permalink
gpu: jit: conv: fix matches_tag()
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh authored and karturov committed Jan 9, 2023
1 parent efd4737 commit d3af877
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/gpu/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,18 @@ void set_default_format(memory_desc_t &md, const std::string &tag) {
md = make_layout(md, tag).to_dnnl(md.dims);
}

bool matches_tag(const layout_t &layout, const std::string &tag) {
bool matches_tag(const layout_t &layout, const std::string &tag,
const std::vector<dim_t> &dims) {
if (layout.is_empty()) return false;
auto tag_layout = make_layout(layout.type(), layout.dims(), tag);
auto tag_layout = make_layout(layout.type(), dims, tag);
if (layout != tag_layout) return false;
return true;
}

bool matches_tag(const layout_t &layout, const std::string &tag) {
return matches_tag(layout, tag, layout.dims());
}

bool matches_tag_strict(const layout_t &layout, const std::string &tag) {
if (layout.is_empty()) return false;
auto tag_layout = make_layout(layout.type(), layout.dims(), tag);
Expand All @@ -67,7 +72,8 @@ bool matches_tag_strict(const layout_t &layout, const std::string &tag) {

bool matches_tag(const memory_desc_t &md, const std::string &tag) {
if (md.format_kind == format_kind::any) return false;
return matches_tag(make_layout(md), tag);
std::vector<dim_t> dims(md.dims, md.dims + md.ndims);
return matches_tag(make_layout(md), tag, dims);
}

bool matches_tag_strict(const memory_desc_t &md, const std::string &tag) {
Expand Down

0 comments on commit d3af877

Please sign in to comment.