Skip to content

Commit

Permalink
gpu: jit: conv: add workaround for incorrect GRF usage estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh committed Mar 24, 2023
1 parent 9197c01 commit 38153b9
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/gpu/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,13 @@ status_t init_cfg(conv_config_t &cfg, const convolution_pd_t *pd) {
// Constructor and destructor are explicitly defaulted here, out of line, in
// the .cpp file instead of inside the class definition.
// NOTE(review): presumably so that member types only need to be complete in
// this translation unit - confirm against the header.
conv_config_t::conv_config_t() = default;
conv_config_t::~conv_config_t() = default;

// Returns the number of registers to account for as "reserved" for this
// configuration: the default constant, plus a fixed padding for problems
// matched by the workaround below.
int conv_config_t::reserved_regs() const {
    // XXX: Workaround for incorrect register estimation. Backward-by-weights
    // problems whose mb is not a multiple of 16 get 4 extra reserved
    // registers on top of the default.
    const bool needs_extra_regs = prb().is_bwd_w && (prb().mb % 16 != 0);
    return constants::reserved_regs_default + (needs_extra_regs ? 4 : 0);
}

void conv_config_t::override_set(const std::string &s, bool is_env) {
std::vector<param_t *> params;
for (auto &gp : get_params_)
Expand Down
4 changes: 2 additions & 2 deletions src/gpu/jit/conv/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ static const int max_slm_bufs = 3;

// GRF usage for kernel arguments, local work IDs/sizes, signal header,
// temporary expressions, etc.
static const int reserved_regs = 16;
static const int reserved_regs_default = 16;
} // namespace constants

struct conv_plan_t;
Expand Down Expand Up @@ -1125,7 +1125,7 @@ class conv_config_t {

// Convenience overload: forwards `name` to the call operator of the object
// returned by the zero-argument unroll() and returns the resulting int.
int unroll(const std::string &name) const { return unroll()(name); }

int reserved_regs() const { return constants::reserved_regs; }
int reserved_regs() const;

// Shortcut for exec_cfg().hw_cfg(); returns the hardware configuration by
// const reference.
const hw_config_t &hw_cfg() const { return exec_cfg().hw_cfg(); }

Expand Down
4 changes: 3 additions & 1 deletion src/gpu/jit/conv/config_plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,7 @@ void conv_plan_t::set_split(abc_kind_t abc, int factor) {
}

grf_usage_t conv_plan_t::grf_usage() const {
ir_assert(reserved_regs != -1);
bool with_headers = !reuse_headers;

int out_buf_regs = 0;
Expand Down Expand Up @@ -1023,7 +1024,7 @@ grf_usage_t conv_plan_t::grf_usage() const {
info.add(grf_usage_label_t::slm_load, slm_load_regs);
info.add(grf_usage_label_t::reorder, reorder_regs);
info.add(grf_usage_label_t::reused_headers, reused_header_regs);
info.add(grf_usage_label_t::reserved, constants::reserved_regs);
info.add(grf_usage_label_t::reserved, reserved_regs);
info.add(grf_usage_label_t::zero_points, zp_regs);
return info;
}
Expand Down Expand Up @@ -1681,6 +1682,7 @@ class plan_builder_t {

plan_status_t try_init_plan() {
plan_.reset();
plan_.reserved_regs = cfg_.reserved_regs();
PLAN_CHECK(init_x_g2r_direct_view(gemm_schedule_.a_tg_view(),
gemm_schedule_.a_thr_tile(), a_direct_view_));
PLAN_CHECK(init_x_g2r_direct_view(gemm_schedule_.b_tg_view(),
Expand Down
1 change: 1 addition & 0 deletions src/gpu/jit/conv/config_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ struct conv_plan_t : public base_plan_t {
int split_factor = 1;
bool reuse_headers = false;
int max_gmem_bufs = 0;
int reserved_regs = -1;

conv_plan_t(ngen::HW hw)
: base_plan_t(hw), slm(hw), prefetch(hw), x2r(hw), fma(hw), zp(hw) {}
Expand Down

0 comments on commit 38153b9

Please sign in to comment.