diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index 021d5e55cad..d5ef2a7f877 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -240,20 +240,27 @@ status_t brgemm_blocking(brgemm_desc_t *brg) { brg->ldb = brg->load_dim / brg->ld_block; brg->ldb_tail = brg->load_dim % brg->ld_block; + const int max_vpad = nstl::max( + brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad); + int adj_ld_block2 = calculate_ldb_params(brg, 4); int max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2); - // reduce 'ld_block2' to allow a larger 'bd_block' - const int max_vpad = nstl::max( - brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad); if (is_superset(brg->isa_impl, avx2) && max_bcast_block < max_vpad) { - adj_ld_block2 = calculate_ldb_params(brg, 2); - max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2); + for (int try_ld_block2 = 2; try_ld_block2 > 0; --try_ld_block2) { + adj_ld_block2 = calculate_ldb_params(brg, try_ld_block2); + max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2); + if (max_bcast_block >= max_vpad) break; + } + // bcast block in brgemm kernel should be greater than virtual + // padding to avoid possible functional issues + if (max_bcast_block < max_vpad) return status::unimplemented; } - const int min_block = 1; + const int min_block = nstl::max(1, max_vpad); + float best_bd_block_eff = 0.f; - brg->bd_block = 1; + brg->bd_block = max_bcast_block; for (int bd_block = max_bcast_block; bd_block >= min_block; bd_block--) { const auto bd_block_disb = static_cast(brg->bcast_dim)