Skip to content

Commit

Permalink
x64: brgemm: avx: bd_block should not be smaller than vpad
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin authored and dzarukin committed Dec 5, 2024
1 parent 19ef223 commit 2eb3dd1
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,20 +240,27 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
brg->ldb = brg->load_dim / brg->ld_block;
brg->ldb_tail = brg->load_dim % brg->ld_block;

const int max_vpad = nstl::max(
brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad);

int adj_ld_block2 = calculate_ldb_params(brg, 4);
int max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);

// reduce 'ld_block2' to allow a larger 'bd_block'
const int max_vpad = nstl::max(
brg->brgattr.max_top_vpad, brg->brgattr.max_bottom_vpad);
if (is_superset(brg->isa_impl, avx2) && max_bcast_block < max_vpad) {
adj_ld_block2 = calculate_ldb_params(brg, 2);
max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
for (int try_ld_block2 = 2; try_ld_block2 > 0; --try_ld_block2) {
adj_ld_block2 = calculate_ldb_params(brg, try_ld_block2);
max_bcast_block = calculate_max_bcast_block(brg, adj_ld_block2);
if (max_bcast_block >= max_vpad) break;
}
// bcast block in brgemm kernel should be greater than virtual
// padding to avoid possible functional issues
if (max_bcast_block < max_vpad) return status::unimplemented;
}

const int min_block = 1;
const int min_block = nstl::max(1, max_vpad);

float best_bd_block_eff = 0.f;
brg->bd_block = 1;
brg->bd_block = max_bcast_block;
for (int bd_block = max_bcast_block; bd_block >= min_block;
bd_block--) {
const auto bd_block_disb = static_cast<float>(brg->bcast_dim)
Expand Down

0 comments on commit 2eb3dd1

Please sign in to comment.