Skip to content

Commit

Permalink
xe: jit: conv: fix typed scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
kealan-barbieri committed Dec 18, 2024
1 parent 4bb3de3 commit accacad
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/gpu/intel/jit/ir/epilogue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class post_op_tensor_t {
// Assign new f32 layout and buffer.
reg_layout_ = std::move(f32_layout);
reg_buf_ = std::move(f32_buf);
info_.retype(type_t::f32());

return ret;
}
Expand Down
20 changes: 15 additions & 5 deletions src/gpu/intel/jit/ir/post_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,25 @@ post_op_context_t::post_op_context_t(const primitive_attr_t &attr,
int src_scales_mask = 0;
int wei_scales_mask = 0;
int dst_scales_mask = 0;
type_t src_scales_type, wei_scales_type, dst_scales_type;
for (int i = 0; i < (int)scale_args.size(); i++) {
auto buf = kernel_info.find_arg(
scale_args[i].first, /*allow_empty=*/true);
if (buf.is_empty()) continue;
int key = kernel_info.key(scale_args[i].first)
& ~DNNL_ARG_ATTR_SCALES;
int mask = attr.scales_.get(key).mask_;
auto scales = attr.scales_.get(key);
if (scales.has_default_values()) continue;
int mask = scales.mask_;
auto sc_type = scales.data_type_ == data_type::undef
? type_t::f32()
: scales.data_type_;
view_t view;
switch (key) {
case DNNL_ARG_SRC:
ir_assert(mask == 0);
view = po_vm_.create_view(type_t::f32(), mask);
src_scales_type = sc_type;
view = po_vm_.create_view(sc_type, mask);
src_scales = add_input_tensor(view, buf);
src_scales_mask = mask;
break;
Expand All @@ -63,14 +70,15 @@ post_op_context_t::post_op_context_t(const primitive_attr_t &attr,
// XXX: per_oc for BWD_D is treated as per_ic assuming it's
// called from deconvolution.
ir_assert(utils::one_of(mask, 0, 1, 3));
view = po_vm_.create_view(
type_t::f32(), (mask) ? 1 << 1 : 0);
wei_scales_type = sc_type;
view = po_vm_.create_view(sc_type, (mask) ? 1 << 1 : 0);
wei_scales = add_input_tensor(view, buf);
wei_scales_mask = mask;
break;
case DNNL_ARG_DST: // Invert dst scales right after load.
ir_assert(utils::one_of(mask, 0, 2));
view = po_vm_.create_view(type_t::f32(), mask);
dst_scales_type = sc_type;
view = po_vm_.create_view(sc_type, mask);
dst_scales = add_input_tensor(view, buf);
dst_scales_mask = mask;
break;
Expand Down Expand Up @@ -273,6 +281,8 @@ bool post_op_context_t::init_need_to_restore_zero_padding(
if (zp_cfg.do_dst_compensation && zp_cfg.is_common_dst_zero_point
&& out_md.dims[1] != out_md.padded_dims[1])
return true;
auto dst_scales = attr.scales_.get(DNNL_ARG_DST);
if (!dst_scales.has_default_values() && dst_scales.mask_ != 0) return true;
return false;
}

Expand Down
2 changes: 2 additions & 0 deletions src/gpu/intel/jit/ir/post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class post_op_tensor_info_t {
return ret;
}

// Replaces the stored view with the result of retyping it to new_type.
void retype(const type_t &new_type) {
    view_ = view_.retype(new_type);
}

// Sets the needs_masked_update_ flag on this tensor info.
void require_masked_update() {
    needs_masked_update_ = true;
}

private:
Expand Down

0 comments on commit accacad

Please sign in to comment.