Skip to content

Commit

Permalink
xe: gemm: fix 2d scale handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kealan-barbieri committed Aug 9, 2024
1 parent 0a077ec commit 894f16c
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions src/gpu/intel/ocl/gemm_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,17 @@ struct gemm_matmul_t : public gpu_primitive_t {
return status::success;
};

auto adjust_scales_mask
= [&](arg_scales_t &scales, int arg, int diff_dims) {
int mask = 0;
bool is_set = false;
CHECK(attr()->scales_.get(arg, &mask, &is_set));
mask = mask >> diff_dims;
if (is_set) {
CHECK(scales.set(arg, mask, 0, nullptr,
attr()->scales_.get(arg).data_type_));
}
return status::success;
};
auto adjust_scales_mask = [&](arg_scales_t &scales, int arg,
int diff_dims) {
int mask = 0, nd = 0;
bool is_set = false;
data_type_t dt = dnnl_data_type_undef;
dims_t dims = {};
CHECK(attr()->scales_.get(arg, &mask, &is_set, &nd, dims, &dt));
mask = mask >> diff_dims;
if (is_set) { CHECK(scales.set(arg, mask, nd, dims, dt)); }
return status::success;
};
if (!attr()->zero_points_.has_default_values()) {
CHECK(map_gemm_zp(DNNL_ARG_SRC, DNNL_ARG_B));
CHECK(map_gemm_zp(
Expand Down

0 comments on commit 894f16c

Please sign in to comment.