From 548d5d67cec2589bf04aed804ab8c7c37ac8d13c Mon Sep 17 00:00:00 2001 From: Aditya Tewari Date: Fri, 8 Nov 2024 14:32:22 +0000 Subject: [PATCH] cpu: aarch64: enable bf16f32 matmul benchdnn: `ONEDNN_VERBOSE=all ./benchdnn --matmul --dt=bf16:bf16:f32 32x128:128x32` Before: onednn_verbose,v1,primitive,create:dispatch,matmul,cpu,matmul,gemm:acl,undef,src:bf16:a:any:any::f0 wei:bf16:a:any:any::f0 dst:f32:a:any:any::f0,,,32x128:128x32,unsupported datatype combination,src/cpu/aarch64/matmul/acl_matmul.cpp:89 After: (picks up ACL kernel) onednn_verbose,v1,primitive,exec:external,CpuGemmAssemblyWrapperKernel/sve_ffhybrid_bf16fp32_mmla_6x4VL,0.242188 onednn_verbose,v1,primitive,exec,cpu,matmul,gemm:acl,undef,src:bf16:a:blocked:ab::f0 wei:bf16:a:blocked:BA8b4a::f0 dst:f32:a:blocked:ab::f0,,,32x128:128x32,0.267822 --- src/cpu/aarch64/matmul/acl_matmul.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp index 3d3e95a491d..64d594c6a9d 100644 --- a/src/cpu/aarch64/matmul/acl_matmul.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul.cpp @@ -76,12 +76,18 @@ status_t acl_matmul_t::pd_t::init(engine_t *engine) { = utils::everyone_is(data_type::bf16, src_md()->data_type, weights_md()->data_type, dst_md()->data_type) && platform::has_data_type_support(data_type::bf16); + const bool is_bf16f32_ok + = utils::everyone_is(data_type::bf16, src_md()->data_type, + weights_md()->data_type) + && utils::everyone_is(data_type::f32, dst_md()->data_type) + && platform::has_data_type_support(data_type::bf16); // we need to save this state as it can change inside set_default_formats() weights_format_kind_ = weights_md_.format_kind; VDISPATCH_MATMUL(is_dense_format_kind(), VERBOSE_UNSUPPORTED_SPARSE_CFG); - VDISPATCH_MATMUL(utils::one_of(true, is_fp32_ok, is_fp16_ok, is_bf16_ok), + VDISPATCH_MATMUL(utils::one_of(true, is_fp32_ok, is_fp16_ok, is_bf16_ok, + is_bf16f32_ok), VERBOSE_UNSUPPORTED_DT_CFG); VDISPATCH_MATMUL(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG);