Skip to content

Commit

Permalink
xe: ocl: gemm: fix gemm_with_post_ops accumulator type
Browse files Browse the repository at this point in the history
  • Loading branch information
rjoursler committed Dec 19, 2024
1 parent 7e450f8 commit a625e4d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
#else
ACC_DATA_T acc = SRC_TO_ACC(src[data_idx]);
#endif
float accumulator = acc;
float accumulator = convert_float(acc);
if ((d0 == D0_WO_PADDING && d1 == D1_WO_PADDING && d2 == D2_WO_PADDING
&& d3 == D3_WO_PADDING)
|| (d0 < D0_WO_PADDING && d1 < D1_WO_PADDING && d2 < D2_WO_PADDING
Expand All @@ -116,7 +116,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
const float b_scale
= B_SCALES ? WEI_SCALES_TO_REF(b_scales[scale_stride * d3]) : 1;
#endif
acc *= A_SCALE * b_scale;
accumulator *= A_SCALE * b_scale;
#endif

#if WITH_BIAS == 1
Expand All @@ -127,7 +127,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
#else
size_t bia_idx = BIA_OFF(d0, d1, 0, 0, 0);
#endif
acc += BIA_TO_ACC(bias[bia_idx]);
accumulator += BIA_TO_ACC(bias[bia_idx]);
#endif

// Apply postops
Expand All @@ -136,7 +136,6 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
sum_src = DST_TO_ACC(dst[data_idx]);
#endif

accumulator = acc;
#if NDIMS == 2
APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, d0, 1, d1, 1,
0, 1, 0, 1, 0, 1, 0, 1);
Expand Down

0 comments on commit a625e4d

Please sign in to comment.