diff --git a/src/gpu/intel/jit/ir/core.cpp b/src/gpu/intel/jit/ir/core.cpp index f3f4e1075ad..0de95469837 100644 --- a/src/gpu/intel/jit/ir/core.cpp +++ b/src/gpu/intel/jit/ir/core.cpp @@ -296,9 +296,9 @@ void ptr_t::normalize(expr_t &base, expr_t &off, op_kind_t op_kind) { return; } - auto &base_off = base.as().off; + off = const_fold_non_recursive( + binary_op_t::make(op_kind, base.as().off, off)); base = base.as().base; - off = const_fold_non_recursive(binary_op_t::make(op_kind, base_off, off)); } expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) { @@ -310,11 +310,8 @@ expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) { void normalize_ptr(const type_t &type, expr_t &base_expr, expr_t &off) { if (base_expr.is()) { - auto &base = base_expr.as().base; - auto &base_off = base_expr.as().off; - - base_expr = base; - off = const_fold_non_recursive(base_off + off); + off = const_fold_non_recursive(base_expr.as().off + off); + base_expr = base_expr.as().base; } ir_assert(to_cpp(off) % type.scalar().size() == 0) << "Incompatible offset: " << off; diff --git a/src/gpu/intel/jit/ir/core.hpp b/src/gpu/intel/jit/ir/core.hpp index 984313c4e62..8ddfe2845b3 100644 --- a/src/gpu/intel/jit/ir/core.hpp +++ b/src/gpu/intel/jit/ir/core.hpp @@ -618,30 +618,27 @@ class object_impl_t { // Downcasts the object to the IR type, returns a reference. The IR type // must match the real IR type. + // N.B.: this can potentially be dangerous if applied to non-const objects, + // since assigning a different value to the source object might make + // the reference dangling due to the destruction of the former object; + // please only call this method on non-const objects if absolutely + // necessary, and please don't add a non-const variant of the method! template const T &as() const { - ir_assert(this->is()); - return *(const T *)this; - } - - template - T &as() { - ir_assert(this->is()); - return *(T *)this; + ir_assert(is()); + return *as_ptr(); // fails on incorrect casts even in Release } // Downcasts the object to the IR type, returns a pointer. If the IR type // doesn't match the real IR type, returns nullptr. + // N.B.: this can potentially be dangerous if applied to non-const objects, + // since assigning a different value to the source object might make + // the reference dangling due to the destruction of the former object; + // please only call this method on non-const objects if absolutely + // necessary, and please don't add a non-const variant of the method! template const T *as_ptr() const { - if (!this->is()) return nullptr; - return (const T *)this; - } - - template - T *as_ptr() { - if (!this->is()) return nullptr; - return (T *)this; + return (is()) ? (const T *)this : nullptr; } // Returns true if T matches the real IR type. @@ -722,24 +719,12 @@ class object_t { return impl_->as(); } - template - T &as() { - ir_assert(impl_); - return impl_->as(); - } - template const T *as_ptr() const { if (!impl_) return nullptr; return impl_->as_ptr(); } - template - T *as_ptr() { - if (!impl_) return nullptr; - return impl_->as_ptr(); - } - template bool is() const { if (is_empty()) return false; diff --git a/src/gpu/intel/jit/pass/simplify.cpp b/src/gpu/intel/jit/pass/simplify.cpp index 25daa3e5180..6c2356d41f7 100644 --- a/src/gpu/intel/jit/pass/simplify.cpp +++ b/src/gpu/intel/jit/pass/simplify.cpp @@ -1230,10 +1230,10 @@ class int_div_mod_expander_t : public nary_op_mutator_t { } expr_t mutate_with_add(const binary_op_t &obj) { - expr_t e = obj; - if (reduce_v1(e)) return e; - if (reduce_v2(e)) return e; - return e; + expr_t ret = reduce_v1(obj); + if (!ret.is_empty()) return ret; + ret = reduce_v2(obj); + return (!ret.is_empty()) ? ret : obj; } // Applies the following rules: @@ -1243,9 +1243,9 @@ class int_div_mod_expander_t : public nary_op_mutator_t { // 2) (A + B) / C -> (A / C) + (B / C), when // - A % C == 0 // - B >= 0 - bool reduce_v1(expr_t &expr) { + expr_t reduce_v1(const expr_t &expr) { auto *binary_op = expr.as_ptr(); - if (!binary_op) return false; + if (!binary_op) return expr_t(); auto op_kind = binary_op->op_kind; auto &a = binary_op->a; @@ -1263,41 +1263,36 @@ class int_div_mod_expander_t : public nary_op_mutator_t { } } - // Nothing to reduce, return expression as is. - if (lhs_args.empty()) return false; + // Nothing to reduce. + if (lhs_args.empty()) return expr_t(); auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args); auto _rhs = nary_op_back_transform(rhs_nary); bool rhs_ge_0 = cset.can_prove(_rhs >= 0); if (op_kind == op_kind_t::_mod) { - if (rhs_args.empty()) { - expr = to_expr(0, expr.type()); - return true; - } - if (!rhs_ge_0) return false; - expr = rhs_nary % b; - return true; + if (rhs_args.empty()) return to_expr(0, expr.type()); + if (!rhs_ge_0) return expr_t(); + return rhs_nary % b; } if (op_kind == op_kind_t::_div) { - if (!rhs_ge_0) return false; + if (!rhs_ge_0) return expr_t(); if (rhs_args.empty()) { - expr = mutate(lhs_args[0] / b); + expr_t ret = mutate(lhs_args[0] / b); for (int i = 1; i < int(lhs_args.size()); i++) { - expr += mutate(lhs_args[i] / b); + ret += mutate(lhs_args[i] / b); } - return true; + return ret; } auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b; auto rhs_div = rhs_nary / b; - expr = mutate(lhs_div) + mutate(rhs_div); - return true; + return mutate(lhs_div) + mutate(rhs_div); } ir_error_not_expected() << expr; - return false; + return expr_t(); } // Applies the following rules: @@ -1309,14 +1304,14 @@ class int_div_mod_expander_t : public nary_op_mutator_t { // - A > 0 // - C > 0 // - 0 <= D < A - bool reduce_v2(expr_t &expr) { + expr_t reduce_v2(const expr_t &expr) { auto *binary_op = expr.as_ptr(); - if (!binary_op) return false; + if (!binary_op) return expr_t(); auto op_kind = binary_op->op_kind; auto &a = binary_op->a; auto &b = binary_op->b; - if (!is_const(b)) return false; + if (!is_const(b)) return expr_t(); auto const_factor = [&](const expr_t &e) { auto _fe = factored_expr_t::make(e); @@ -1337,7 +1332,7 @@ class int_div_mod_expander_t : public nary_op_mutator_t { if (gcd > max_gcd) max_gcd = gcd; } - if (max_gcd == 0) return false; + if (max_gcd == 0) return expr_t(); std::vector lhs_args; // Reducible summands. std::vector rhs_args; // Non-reducible summands. @@ -1353,32 +1348,30 @@ class int_div_mod_expander_t : public nary_op_mutator_t { // reducible. ir_assert(!lhs_args.empty()); - if (rhs_args.empty()) return false; + if (rhs_args.empty()) return expr_t(); int64_t A = max_gcd; int64_t C = to_cpp(b) / A; - if (A <= 0 || C <= 0) return false; + if (A <= 0 || C <= 0) return expr_t(); auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args); auto D = nary_op_back_transform(rhs_nary); - if (!cset.can_prove(D >= 0) || !cset.can_prove(D < A)) return false; + if (!cset.can_prove(D >= 0) || !cset.can_prove(D < A)) return expr_t(); if (op_kind == op_kind_t::_mod) { auto lhs_mod = make_nary_op(op_kind_t::_add, lhs_args) % b; auto rhs_mod = rhs_nary % b; - expr = mutate(lhs_mod) + mutate(rhs_mod); - return true; + return mutate(lhs_mod) + mutate(rhs_mod); } if (op_kind == op_kind_t::_div) { auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b; - expr = lhs_div; - return true; + return lhs_div; } ir_error_not_expected() << expr; - return false; + return expr_t(); } bool is_div_reducible(const expr_t &a, const expr_t &b) const {