diff --git a/src/gpu/jit/conv/config.cpp b/src/gpu/jit/conv/config.cpp index eff59a07e61..dd153c199a0 100644 --- a/src/gpu/jit/conv/config.cpp +++ b/src/gpu/jit/conv/config.cpp @@ -2095,6 +2095,13 @@ status_t init_cfg(conv_config_t &cfg, const convolution_pd_t *pd) { conv_config_t::conv_config_t() = default; conv_config_t::~conv_config_t() = default; +int conv_config_t::reserved_regs() const { + int ret = constants::reserved_regs_default; + // XXX: Workaround for incorrect register estimation: reserve 4 extra registers when prb().is_bwd_w is set and prb().mb % 16 != 0. + if (prb().is_bwd_w && prb().mb % 16 != 0) ret += 4; + return ret; +} + void conv_config_t::override_set(const std::string &s, bool is_env) { std::vector params; for (auto &gp : get_params_) diff --git a/src/gpu/jit/conv/config.hpp b/src/gpu/jit/conv/config.hpp index 84152bc3982..ee1074febbc 100644 --- a/src/gpu/jit/conv/config.hpp +++ b/src/gpu/jit/conv/config.hpp @@ -1007,7 +1007,7 @@ static const int max_slm_bufs = 3; // GRF usage for kernel arguments, local work IDs/sizes, signal header, // temporary expressions, etc. 
-static const int reserved_regs = 16; +static const int reserved_regs_default = 16; } // namespace constants struct conv_plan_t; @@ -1125,7 +1125,7 @@ class conv_config_t { int unroll(const std::string &name) const { return unroll()(name); } - int reserved_regs() const { return constants::reserved_regs; } + int reserved_regs() const; const hw_config_t &hw_cfg() const { return exec_cfg().hw_cfg(); } diff --git a/src/gpu/jit/conv/config_plan.cpp b/src/gpu/jit/conv/config_plan.cpp index 58c7a8b70af..0ac0bd2e4bc 100644 --- a/src/gpu/jit/conv/config_plan.cpp +++ b/src/gpu/jit/conv/config_plan.cpp @@ -937,6 +937,7 @@ void conv_plan_t::set_split(abc_kind_t abc, int factor) { } grf_usage_t conv_plan_t::grf_usage() const { + ir_assert(reserved_regs != -1); bool with_headers = !reuse_headers; int out_buf_regs = 0; @@ -1023,7 +1024,7 @@ grf_usage_t conv_plan_t::grf_usage() const { info.add(grf_usage_label_t::slm_load, slm_load_regs); info.add(grf_usage_label_t::reorder, reorder_regs); info.add(grf_usage_label_t::reused_headers, reused_header_regs); - info.add(grf_usage_label_t::reserved, constants::reserved_regs); + info.add(grf_usage_label_t::reserved, reserved_regs); info.add(grf_usage_label_t::zero_points, zp_regs); return info; } @@ -1681,6 +1682,7 @@ class plan_builder_t { plan_status_t try_init_plan() { plan_.reset(); + plan_.reserved_regs = cfg_.reserved_regs(); PLAN_CHECK(init_x_g2r_direct_view(gemm_schedule_.a_tg_view(), gemm_schedule_.a_thr_tile(), a_direct_view_)); PLAN_CHECK(init_x_g2r_direct_view(gemm_schedule_.b_tg_view(), diff --git a/src/gpu/jit/conv/config_plan.hpp b/src/gpu/jit/conv/config_plan.hpp index 840adc7ed18..57425c08ea9 100644 --- a/src/gpu/jit/conv/config_plan.hpp +++ b/src/gpu/jit/conv/config_plan.hpp @@ -221,6 +221,7 @@ struct conv_plan_t : public base_plan_t { int split_factor = 1; bool reuse_headers = false; int max_gmem_bufs = 0; + int reserved_regs = -1; conv_plan_t(ngen::HW hw) : base_plan_t(hw), slm(hw), prefetch(hw), x2r(hw), fma(hw), 
zp(hw) {}