From 188ae7f3e3410a76a81d10b2b2bed3be7afdc307 Mon Sep 17 00:00:00 2001 From: Alexey Makarevich Date: Mon, 21 Oct 2024 16:13:43 -0700 Subject: [PATCH] matmul: x64: added support for bf16,f16 bias dt --- src/cpu/x64/brgemm/brgemm.cpp | 10 ++++++---- src/cpu/x64/matmul/brgemm_matmul.cpp | 5 ++++- tests/benchdnn/inputs/matmul/test_matmul_fp8 | 6 +++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 9feacb7207c..56d5c83a47c 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -348,13 +348,15 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, data_type::f16))) return status::unimplemented; const auto bias_f8_e5m2_compatible - = one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e5m2) + = one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16, + data_type::f8_e5m2) && one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16, - data_type::f8_e5m2, data_type::f8_e4m3); + data_type::bf16, data_type::f8_e5m2, data_type::f8_e4m3); const auto bias_f8_e4m3_compatible - = one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e4m3) + = one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16, + data_type::f8_e4m3) && one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16, - data_type::f8_e4m3, data_type::f8_e5m2); + data_type::bf16, data_type::f8_e4m3, data_type::f8_e5m2); if (!IMPLICATION(brg->is_fp8, bias_f8_e5m2_compatible || bias_f8_e4m3_compatible)) return status::unimplemented; diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 250ab7afc07..b17cac2dd09 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -70,7 +70,10 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { const bool is_bia_dt_correct = IMPLICATION(is_int8 == true, one_of(bia_dt, f32, s32, s8, u8, bf16)) - && IMPLICATION(!is_int8, one_of(bia_dt, f32, src_dt)); + && IMPLICATION( + is_f8 == true, one_of(bia_dt, f32, f16, bf16, src_dt)) + && IMPLICATION( + !(is_int8 || is_f8), one_of(bia_dt, f32, src_dt)); return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN()); }; diff --git a/tests/benchdnn/inputs/matmul/test_matmul_fp8 b/tests/benchdnn/inputs/matmul/test_matmul_fp8 index 43f17f67568..ddd1445d8de 100644 --- a/tests/benchdnn/inputs/matmul/test_matmul_fp8 +++ b/tests/benchdnn/inputs/matmul/test_matmul_fp8 @@ -4,7 +4,7 @@ --dt=f8_e4m3:f8_e4m3:f32,f8_e4m3,f8_e5m2:f8_e5m2:f32,f8_e5m2 --stag=ab,ba --wtag=ab,ba --dtag=ab --runtime_dims_masks=0,2:1,1:0,3:1 ---bia_dt=undef,f32 --bia_mask=2 +--bia_dt=undef,f32,f16,bf16 --bia_mask=2 --attr-scales= --attr-post-ops= @@ -21,8 +21,8 @@ --stag=ba --wtag=ab,ba --dtag=ab --runtime_dims_masks=3:1,3:3 ---bia_dt=f8_e4m3,f8_e5m2 --bia_mask=1,2,3 ---attr-scales=src:common:0.25+wei:common:0.5+dst:common:2.25 +--bia_dt=f8_e4m3,f8_e5m2,f16,bf16 --bia_mask=1,2,3 +--attr-scales=src:common:0.25+wei:common:0.5+dst:common:4 --attr-post-ops=add:f32,sum+mul:s32:per_oc+linear:2:-1 --batch=shapes_2d