Skip to content

Commit

Permalink
src: cpu: reorder: jit: fix zero-point compensation cases with tail
Browse files Browse the repository at this point in the history
  • Loading branch information
msotoflo authored and vpirogov committed Sep 7, 2022
1 parent 73b7572 commit b340cba
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion src/cpu/x64/jit_uni_reorder_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,44 @@ static bool is_with_groups(const memory_desc_t &dst_md) {
dst_d.extra().asymm_compensation_mask);
}

static inline int get_next_parent_node(node_t *nodes, int ndims, int cur_node) {
const int cur_id = nodes[cur_node].dim_id;
for (int d = cur_node + 1; d < ndims; ++d) {
if (nodes[d].dim_id == cur_id) return d;
}
return -1;
}

static void prb_set_compensation_strides(prb_t &p) {

auto require_n_stride = [&](int cur_node) -> bool {
const int parent = get_next_parent_node(p.nodes, p.ndims, cur_node);
if (parent < 0) return false;

const size_t p_n = p.nodes[parent].n;

// if 'parent_node.n' is larger than 1, then cur_node stride
// is 'cur_node.n'
return p_n > size_t(1);
};

const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp;
if (!compensation_needed) return;
int mask = p.compensation_mask;
ptrdiff_t cs = 1;
for (int d = 0; d < p.ndims; ++d) {
if (mask & (1 << p.nodes[d].dim_id)) {

// correct cases when 'cs' exceeds output stride
if (cs > p.nodes[d].os) cs = p.nodes[d].os;

p.nodes[d].cs = cs;
cs = cs * p.nodes[d].n;
const bool n_stride = require_n_stride(d);
if (p.nodes[d].tail_size > 0 && (!p.nodes[d].is_zero_pad_needed)
&& (!n_stride))
cs *= p.nodes[d].tail_size;
else
cs *= p.nodes[d].n;
}
}
}
Expand Down

0 comments on commit b340cba

Please sign in to comment.