Skip to content

Commit

Permalink
cpu: aarch64: fix acl matmul dim guard for 4d tensor broadcast
Browse files Browse the repository at this point in the history
Signed-off-by: Ye Tao <[email protected]>
  • Loading branch information
taoye9 authored and Radu2k committed Dec 2, 2024
1 parent 583215d commit b3be239
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/cpu/aarch64/matmul/acl_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,26 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
// for e.g when ab in abcd is 1x1
bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1)
&& IMPLICATION(wei_batch > 1, src_batch == 1);

ACL_CHECK_SUPPORT(src_d.ndims() == 4 && src_batch != wei_batch && !batch_ok,
"matmul broadcast supported only for 3D shapes and 4D shapes when "
"ab is 1x1");

if (src_d.ndims() == 4 && src_batch == wei_batch
&& src_d.dims()[0] != wei_d.dims()[0]) { // 4D broadcast occurred
if (src_d.dims()[0] == 1 && wei_d.dims()[0] != 1) { // Broadcast src
ACL_CHECK_SUPPORT(
IMPLICATION(src_d.dims()[1] != 1, wei_d.dims()[1] == 1),
"acl only broadcasts one of src or wei at once");
}

if (wei_d.dims()[0] == 1 && src_d.dims()[0] != 1) { // Broadcast wei
ACL_CHECK_SUPPORT(
IMPLICATION(src_d.dims()[1] == 1, wei_d.dims()[1] != 1),
"acl only broadcasts one of src or wei at once");
}
}

// ACL does not support bias
bool with_bias = md.bias_desc.format_kind != format_kind::undef;
ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul");
Expand Down
1 change: 1 addition & 0 deletions tests/benchdnn/inputs/matmul/shapes_4d
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
74x16x54x64:74x16x64x54
1x1x35x64:113x16x64x35
1x16x38x64:105x1x64x38
1x3x35x64:3x1x64x35
74x16x54x64:1x1x64x54n"B_full_bcast"
74x6x1x253:1x1x253x1n"dot_prod_w_B_full_bcast"

0 comments on commit b3be239

Please sign in to comment.