Skip to content

Commit

Permalink
src: gpu: intel: jit: prohibit non-const as/as_ptr, fix dangling refs
Browse files Browse the repository at this point in the history
  • Loading branch information
hidefromkgb authored and karturov committed Jun 10, 2024
1 parent 5ea773e commit 5f5f0aa
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 69 deletions.
11 changes: 4 additions & 7 deletions src/gpu/intel/jit/ir/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ void ptr_t::normalize(expr_t &base, expr_t &off, op_kind_t op_kind) {
return;
}

auto &base_off = base.as<ptr_t>().off;
off = const_fold_non_recursive(
binary_op_t::make(op_kind, base.as<ptr_t>().off, off));
base = base.as<ptr_t>().base;
off = const_fold_non_recursive(binary_op_t::make(op_kind, base_off, off));
}

expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
Expand All @@ -310,11 +310,8 @@ expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) {

void normalize_ptr(const type_t &type, expr_t &base_expr, expr_t &off) {
if (base_expr.is<ptr_t>()) {
auto &base = base_expr.as<ptr_t>().base;
auto &base_off = base_expr.as<ptr_t>().off;

base_expr = base;
off = const_fold_non_recursive(base_off + off);
off = const_fold_non_recursive(base_expr.as<ptr_t>().off + off);
base_expr = base_expr.as<ptr_t>().base;
}
ir_assert(to_cpp<int64_t>(off) % type.scalar().size() == 0)
<< "Incompatible offset: " << off;
Expand Down
41 changes: 13 additions & 28 deletions src/gpu/intel/jit/ir/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,30 +618,27 @@ class object_impl_t {

// Downcasts the object to the IR type, returns a reference. The IR type
// must match the real IR type.
// N.B.: this can potentially be dangerous if applied to non-const objects,
// since assigning a different value to the source object might make
// the reference dangling due to the destruction of the former object;
// please only call this method on non-const objects if absolutely
// necessary, and please don't add a non-const variant of the method!
template <typename T>
const T &as() const {
ir_assert(this->is<T>());
return *(const T *)this;
}

template <typename T>
T &as() {
ir_assert(this->is<T>());
return *(T *)this;
ir_assert(is<T>());
return *as_ptr<T>(); // fails on incorrect casts even in Release
}

// Downcasts the object to the IR type, returns a pointer. If the IR type
// doesn't match the real IR type, returns nullptr.
// N.B.: this can potentially be dangerous if applied to non-const objects,
// since assigning a different value to the source object might make
// the reference dangling due to the destruction of the former object;
// please only call this method on non-const objects if absolutely
// necessary, and please don't add a non-const variant of the method!
template <typename T>
const T *as_ptr() const {
if (!this->is<T>()) return nullptr;
return (const T *)this;
}

template <typename T>
T *as_ptr() {
if (!this->is<T>()) return nullptr;
return (T *)this;
return (is<T>()) ? (const T *)this : nullptr;
}

// Returns true if T matches the real IR type.
Expand Down Expand Up @@ -722,24 +719,12 @@ class object_t {
return impl_->as<T>();
}

template <typename T>
T &as() {
ir_assert(impl_);
return impl_->as<T>();
}

template <typename T>
const T *as_ptr() const {
if (!impl_) return nullptr;
return impl_->as_ptr<T>();
}

template <typename T>
T *as_ptr() {
if (!impl_) return nullptr;
return impl_->as_ptr<T>();
}

template <typename T>
bool is() const {
if (is_empty()) return false;
Expand Down
61 changes: 27 additions & 34 deletions src/gpu/intel/jit/pass/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,10 +1230,10 @@ class int_div_mod_expander_t : public nary_op_mutator_t {
}

expr_t mutate_with_add(const binary_op_t &obj) {
expr_t e = obj;
if (reduce_v1(e)) return e;
if (reduce_v2(e)) return e;
return e;
expr_t ret = reduce_v1(obj);
if (!ret.is_empty()) return ret;
ret = reduce_v2(obj);
return (!ret.is_empty()) ? ret : obj;
}

// Applies the following rules:
Expand All @@ -1243,9 +1243,9 @@ class int_div_mod_expander_t : public nary_op_mutator_t {
// 2) (A + B) / C -> (A / C) + (B / C), when
// - A % C == 0
// - B >= 0
bool reduce_v1(expr_t &expr) {
expr_t reduce_v1(const expr_t &expr) {
auto *binary_op = expr.as_ptr<binary_op_t>();
if (!binary_op) return false;
if (!binary_op) return expr_t();

auto op_kind = binary_op->op_kind;
auto &a = binary_op->a;
Expand All @@ -1263,41 +1263,36 @@ class int_div_mod_expander_t : public nary_op_mutator_t {
}
}

// Nothing to reduce, return expression as is.
if (lhs_args.empty()) return false;
// Nothing to reduce.
if (lhs_args.empty()) return expr_t();

auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args);
auto _rhs = nary_op_back_transform(rhs_nary);
bool rhs_ge_0 = cset.can_prove(_rhs >= 0);

if (op_kind == op_kind_t::_mod) {
if (rhs_args.empty()) {
expr = to_expr(0, expr.type());
return true;
}
if (!rhs_ge_0) return false;
expr = rhs_nary % b;
return true;
if (rhs_args.empty()) return to_expr(0, expr.type());
if (!rhs_ge_0) return expr_t();
return rhs_nary % b;
}

if (op_kind == op_kind_t::_div) {
if (!rhs_ge_0) return false;
if (!rhs_ge_0) return expr_t();
if (rhs_args.empty()) {
expr = mutate(lhs_args[0] / b);
expr_t ret = mutate(lhs_args[0] / b);
for (int i = 1; i < int(lhs_args.size()); i++) {
expr += mutate(lhs_args[i] / b);
ret += mutate(lhs_args[i] / b);
}
return true;
return ret;
}
auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b;
auto rhs_div = rhs_nary / b;
expr = mutate(lhs_div) + mutate(rhs_div);
return true;
return mutate(lhs_div) + mutate(rhs_div);
}

ir_error_not_expected() << expr;

return false;
return expr_t();
}

// Applies the following rules:
Expand All @@ -1309,14 +1304,14 @@ class int_div_mod_expander_t : public nary_op_mutator_t {
// - A > 0
// - C > 0
// - 0 <= D < A
bool reduce_v2(expr_t &expr) {
expr_t reduce_v2(const expr_t &expr) {
auto *binary_op = expr.as_ptr<binary_op_t>();
if (!binary_op) return false;
if (!binary_op) return expr_t();

auto op_kind = binary_op->op_kind;
auto &a = binary_op->a;
auto &b = binary_op->b;
if (!is_const(b)) return false;
if (!is_const(b)) return expr_t();

auto const_factor = [&](const expr_t &e) {
auto _fe = factored_expr_t::make(e);
Expand All @@ -1337,7 +1332,7 @@ class int_div_mod_expander_t : public nary_op_mutator_t {
if (gcd > max_gcd) max_gcd = gcd;
}

if (max_gcd == 0) return false;
if (max_gcd == 0) return expr_t();

std::vector<expr_t> lhs_args; // Reducible summands.
std::vector<expr_t> rhs_args; // Non-reducible summands.
Expand All @@ -1353,32 +1348,30 @@ class int_div_mod_expander_t : public nary_op_mutator_t {
// reducible.
ir_assert(!lhs_args.empty());

if (rhs_args.empty()) return false;
if (rhs_args.empty()) return expr_t();

int64_t A = max_gcd;
int64_t C = to_cpp<int64_t>(b) / A;
if (A <= 0 || C <= 0) return false;
if (A <= 0 || C <= 0) return expr_t();

auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args);
auto D = nary_op_back_transform(rhs_nary);
if (!cset.can_prove(D >= 0) || !cset.can_prove(D < A)) return false;
if (!cset.can_prove(D >= 0) || !cset.can_prove(D < A)) return expr_t();

if (op_kind == op_kind_t::_mod) {
auto lhs_mod = make_nary_op(op_kind_t::_add, lhs_args) % b;
auto rhs_mod = rhs_nary % b;
expr = mutate(lhs_mod) + mutate(rhs_mod);
return true;
return mutate(lhs_mod) + mutate(rhs_mod);
}

if (op_kind == op_kind_t::_div) {
auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b;
expr = lhs_div;
return true;
return lhs_div;
}

ir_error_not_expected() << expr;

return false;
return expr_t();
}

bool is_div_reducible(const expr_t &a, const expr_t &b) const {
Expand Down

0 comments on commit 5f5f0aa

Please sign in to comment.