diff --git a/tests/benchdnn/dnnl_common.cpp b/tests/benchdnn/dnnl_common.cpp index c534a9ce23b..9f97eb005e9 100644 --- a/tests/benchdnn/dnnl_common.cpp +++ b/tests/benchdnn/dnnl_common.cpp @@ -1494,19 +1494,20 @@ int update_ref_mem_map_from_prim(dnnl_primitive_t prim_ref, // have dedicated query mechanism for those. Process potential outcomes: while (query_md_ndims(ref_md) == 0) { bool is_scales_arg = (exec_arg & DNNL_ARG_ATTR_SCALES); - // Ref memory for scales is f32, the library expects it same data type. - // Skip replacement. + // Scales received data type support in the library. The reference + // primitive expects them in the same data type. if (is_scales_arg) { - skip_replace = true; + prim_ref_mem = dnn_mem_t( + library_mem.md_, library_mem.dt(), tag::abx, ref_engine); break; } bool is_zero_point_arg = (exec_arg & DNNL_ARG_ATTR_ZERO_POINTS); - // Ref memory for zps is f32, but the library expects it in s32. Update - // the memory and proceed to replacement. + // Zero-points received data type support in the library. The reference + // primitive expects them in the same data type. if (is_zero_point_arg) { prim_ref_mem = dnn_mem_t( - library_mem.md_, dnnl_s32, tag::abx, ref_engine); + library_mem.md_, library_mem.dt(), tag::abx, ref_engine); break; }