Skip to content

Commit

Permalink
x64: brgemm convolution: update req_cal_comp_pad condition
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed Nov 18, 2024
1 parent 48f6bd9 commit 05d68df
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/cpu/x64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,6 +1526,10 @@ status_t brgemm_convolution_fwd_t<isa>::cal_compensation(

const int max_ker_sz = adjusted_k.size();
const auto comp_buffer_ow = jcp.exec_type != exec_vpad ? jcp.ow : 1;
// TODO: revise the thread distribution here because the work_amount may be
// insufficient
// TODO: revise comp_vpad_pbuffer_ generator to avoid huge code for cases
// with big ow
const auto work_amount
= static_cast<dim_t>(jcp.ngroups) * jcp.nb_oc * max_ker_sz;
const auto is_small_shape = work_amount <= jcp.nthr
Expand Down
12 changes: 9 additions & 3 deletions src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2296,12 +2296,18 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,

// For padding shapes, we calculate the comp along with the computation
// inside brgemm kernel when output size is small to get optimal perf
// Or we calculate the comp using brgemm_coomp_pad kernel
// For shapes with large ow we calculate the comp inside brgemm kernel too
// because current implementation of brgemm_comp_pad kernel unrolled by ow
// so not optimal for large ow.
// Otherwise we calculate the comp using brgemm_comp_pad kernel
const auto output_sz = static_cast<dim_t>(jcp.mb) * jcp.ngroups * jcp.oc
* jcp.od * jcp.oh * jcp.ow;
// TODO: revise below condition to avoid limitation for big ow
const auto shape_for_brgemm_kernel
= (output_sz <= 8192 && jcp.oc < 512) || jcp.ow > 128;
const auto is_relo = jcp.is_relo() && jcp.relo_conv_weights;
jcp.req_brg_comp_pad = compensation_w_padding && jcp.exec_type != exec_trans
&& IMPLICATION(!(jcp.is_relo() && jcp.relo_conv_weights),
output_sz <= 8192 && jcp.oc < 512);
&& IMPLICATION(!is_relo, shape_for_brgemm_kernel);
jcp.req_cal_comp_pad = compensation_w_padding && !jcp.req_brg_comp_pad
&& IMPLICATION(jcp.exec_type == exec_vpad,
jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
Expand Down

0 comments on commit 05d68df

Please sign in to comment.