Skip to content

Commit

Permalink
gpu: jit: conv: avoid explicit cast when fixing 32-bit overflow if po…
Browse files Browse the repository at this point in the history
…ssible
  • Loading branch information
echeresh committed Apr 14, 2023
1 parent e3cb07d commit 31ac0e0
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 91 deletions.
1 change: 1 addition & 0 deletions src/gpu/jit/ir/ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ expr_t nary_op_back_transform(const expr_t &e);
expr_t nary_op_canonicalize(const expr_t &_e);
expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args);
std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e);
expr_t reorder_nary_add_args(const expr_t &e, bool x64_first);

// Substitutes all occurrences of `from` to `to` in `root`.
object_t substitute(const object_t &root, const object_t &from,
Expand Down
178 changes: 107 additions & 71 deletions src/gpu/jit/pass/overflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,100 @@ class overflow_bound_finder_t : public bound_finder_base_t {
object_map_t<expr_t, std::pair<int64_t, int64_t>> var_bounds_;
};

struct overflow_context_t {
overflow_bound_finder_t bound_finder;
object_map_t<expr_t, std::vector<expr_t>> vec_vars;
object_set_t<expr_t> vars_with_load;

bool contains_load(const expr_t &e) const {
if (!find_objects<load_t>(e).empty()) return true;
for (auto &v : find_objects<var_t>(e)) {
if (vars_with_load.count(v) != 0) return true;
}
return false;
}
};

class expr_overflow_fixer_t : public ir_mutator_t {
public:
expr_overflow_fixer_t(const overflow_context_t &ctx) : ctx_(ctx) {}

object_t _mutate(const binary_op_t &obj) override {
return mutate_expr(obj);
}

object_t _mutate(const unary_op_t &obj) override {
return mutate_expr(obj);
}

private:
template <typename T>
object_t mutate_expr(const T &obj) {
expr_t new_obj = ir_mutator_t::_mutate(obj);
if (!new_obj.type().is_x32()) return std::move(new_obj);
if (ctx_.contains_load(new_obj)) return std::move(new_obj);

bool found_overflow = false;
int elems = new_obj.type().elems();
for (int i = 0; i < elems; i++) {
expr_scalarizer_t scalarizer(elems, i, ctx_.vec_vars);
expr_t value = scalarizer.mutate(new_obj);
int64_t lo = ctx_.bound_finder.find_low_bound(value);
int64_t hi = ctx_.bound_finder.find_high_bound(value);
bool ok = bound_finder_base_t::is_good_bound(lo)
&& bound_finder_base_t::is_good_bound(hi);
if (ok) {
int64_t type_lo = value.type().is_s32()
? (int64_t)std::numeric_limits<int32_t>::min()
: (int64_t)std::numeric_limits<uint32_t>::min();
int64_t type_hi = value.type().is_s32()
? (int64_t)std::numeric_limits<int32_t>::max()
: (int64_t)std::numeric_limits<uint32_t>::max();

bool is_overflow = (lo < type_lo || hi > type_hi);
if (is_overflow) {
found_overflow = true;
ir_warning() << "Found overflow: " << value
<< " low bound: " << lo
<< " high bound: " << hi << std::endl;
break;
}
}
}
if (found_overflow) return fix_overflow(new_obj);
return std::move(new_obj);
}

static expr_t fix_overflow(const expr_t &e) {
auto *binary = e.as_ptr<binary_op_t>();
if (binary) {
return binary_op_t::make(binary->op_kind,
cast(binary->a, type_t::u64(e.type().elems())), binary->b);
}

ir_error_not_expected() << "Can't fix overflow: " << e;
return e;
}

const overflow_context_t &ctx_;
};

expr_t fix_expr_overflow(const expr_t &e, const overflow_context_t &ctx) {
auto e_fixed = expr_overflow_fixer_t(ctx).mutate(e);
if (e_fixed.is_same(e)) return e;

// Overflow detected, try to rearrange summands and avoid explicit casting.
auto nary = reorder_nary_add_args(
nary_op_canonicalize(e), /*x64_first=*/true);
auto e_reordered = nary_op_back_transform(nary);
auto e_reordered_fixed = expr_overflow_fixer_t(ctx).mutate(e_reordered);
if (e_reordered_fixed.is_same(e_reordered)) {
// No overflow detected after rearranging, return it.
return e_reordered;
}
return e_fixed;
}

class overflow_fixer_t : public ir_mutator_t {
public:
overflow_fixer_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {
Expand All @@ -90,7 +184,7 @@ class overflow_fixer_t : public ir_mutator_t {
<< to_string(rel.op_kind());
}
}
bound_finder_.set_var_bounds(kv.first, {lo, hi});
ctx_.bound_finder.set_var_bounds(kv.first, {lo, hi});
}
}

Expand All @@ -99,13 +193,13 @@ class overflow_fixer_t : public ir_mutator_t {
}

object_t _mutate(const binary_op_t &obj) override {
return mutate_expr(obj);
return fix_expr_overflow(obj, ctx_);
}

object_t _mutate(const for_t &obj) override {
auto lo = to_cpp<int64_t>(obj.init);
auto hi = to_cpp<int64_t>(obj.bound) - 1;
bound_finder_.set_var_bounds(obj.var, {lo, hi});
ctx_.bound_finder.set_var_bounds(obj.var, {lo, hi});
return ir_mutator_t::_mutate(obj);
}

Expand All @@ -114,25 +208,25 @@ class overflow_fixer_t : public ir_mutator_t {
if (!obj.var.type().is_int()) ok = false;
if (ok && obj.value.is_empty()) ok = false;
if (ok && obj.value.type().is_bool()) ok = false;
if (ok && bound_finder_.has_var(obj.var)) ok = false;
if (ok && ctx_.bound_finder.has_var(obj.var)) ok = false;

if (ok) {
if (contains_load(obj.value)) {
vars_with_load_.insert(obj.var);
if (ctx_.contains_load(obj.value)) {
ctx_.vars_with_load.insert(obj.var);
ok = false;
}
}

if (ok) {
int elems = obj.var.type().elems();
vec_vars_[obj.var].reserve(elems);
ctx_.vec_vars[obj.var].reserve(elems);
for (int i = 0; i < elems; i++) {
auto var_i = make_vec_var(obj.var, elems, i);
expr_scalarizer_t scalarizer(elems, i, vec_vars_);
expr_scalarizer_t scalarizer(elems, i, ctx_.vec_vars);
auto value_i = scalarizer.mutate(obj.value);
auto lo_hi = bound_finder_.find_bounds(value_i);
bound_finder_.set_var_bounds(var_i, lo_hi);
vec_vars_[obj.var].push_back(var_i);
auto lo_hi = ctx_.bound_finder.find_bounds(value_i);
ctx_.bound_finder.set_var_bounds(var_i, lo_hi);
ctx_.vec_vars[obj.var].push_back(var_i);
}
}
expr_t var = obj.var;
Expand All @@ -150,77 +244,19 @@ class overflow_fixer_t : public ir_mutator_t {
}

object_t _mutate(const unary_op_t &obj) override {
return mutate_expr(obj);
return fix_expr_overflow(obj, ctx_);
}

private:
template <typename T>
object_t mutate_expr(const T &obj) {
expr_t new_obj = ir_mutator_t::_mutate(obj);
if (!new_obj.type().is_x32()) return std::move(new_obj);
if (contains_load(new_obj)) return std::move(new_obj);

bool found_overflow = false;
int elems = new_obj.type().elems();
for (int i = 0; i < elems; i++) {
expr_scalarizer_t scalarizer(elems, i, vec_vars_);
expr_t value = scalarizer.mutate(new_obj);
int64_t lo = bound_finder_.find_low_bound(value);
int64_t hi = bound_finder_.find_high_bound(value);
bool ok = bound_finder_base_t::is_good_bound(lo)
&& bound_finder_base_t::is_good_bound(hi);
if (ok) {
int64_t type_lo = value.type().is_s32()
? (int64_t)std::numeric_limits<int32_t>::min()
: (int64_t)std::numeric_limits<uint32_t>::min();
int64_t type_hi = value.type().is_s32()
? (int64_t)std::numeric_limits<int32_t>::max()
: (int64_t)std::numeric_limits<uint32_t>::max();

bool is_overflow = (lo < type_lo || hi > type_hi);
if (is_overflow) {
found_overflow = true;
ir_warning() << "Found overflow: " << value
<< " low bound: " << lo
<< " high bound: " << hi << std::endl;
break;
}
}
}
if (found_overflow) return fix_overflow(new_obj);
return std::move(new_obj);
}

bool contains_load(const expr_t &e) const {
if (!find_objects<load_t>(e).empty()) return true;
for (auto &v : find_objects<var_t>(e)) {
if (vars_with_load_.count(v) != 0) return true;
}
return false;
}

static expr_t make_vec_var(const expr_t &_var, int elems, int idx) {
if (elems == 1) return _var;
auto &var = _var.as<var_t>();
auto vec_name = var.name + "_" + std::to_string(idx) + "_";
return var_t::make(var.type.scalar(), vec_name);
}

static expr_t fix_overflow(const expr_t &e) {
auto *binary = e.as_ptr<binary_op_t>();
if (binary) {
return binary_op_t::make(binary->op_kind,
cast(binary->a, type_t::u64(e.type().elems())), binary->b);
}

ir_error_not_expected() << "Can't fix overflow: " << e;
return e;
}

ir_context_t &ir_ctx_;
overflow_bound_finder_t bound_finder_;
object_map_t<expr_t, std::vector<expr_t>> vec_vars_;
object_set_t<expr_t> vars_with_load_;
overflow_context_t ctx_;
};

stmt_t fix_int32_overflow(const stmt_t &s, ir_context_t &ir_ctx) {
Expand Down
45 changes: 25 additions & 20 deletions src/gpu/jit/pass/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,30 @@ class common_factor_simplifier_t : public nary_op_mutator_t {
}
};

expr_t reorder_nary_add_args(const expr_t &e, bool x64_first) {
auto *nary_op = e.as_ptr<nary_op_t>();
if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2)
return e;

std::vector<expr_t> other_args;
std::vector<expr_t> x64_args;
for (auto &a : nary_op->args) {
if (a.type().is_x64()) {
x64_args.push_back(a);
} else {
other_args.push_back(a);
}
}

if (other_args.empty() || x64_args.empty()) return e;

std::vector<expr_t> new_args = std::move(other_args);
new_args.insert(x64_first ? new_args.begin() : new_args.end(),
x64_args.begin(), x64_args.end());

return nary_op_t::make(nary_op->op_kind, new_args);
}

// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
// arithmetic. Example:
// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
Expand All @@ -1505,26 +1529,7 @@ class _64_bit_add_optimizer_t : public nary_op_mutator_t {
public:
object_t _mutate(const nary_op_t &obj) override {
auto new_obj = nary_op_mutator_t::_mutate(obj);
auto *nary_op = new_obj.as_ptr<nary_op_t>();
if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2)
return new_obj;

std::vector<expr_t> other_args;
std::vector<expr_t> x64_args;
for (auto &a : nary_op->args) {
if (a.type().is_x64()) {
x64_args.push_back(a);
} else {
other_args.push_back(a);
}
}

if (other_args.empty() || x64_args.empty()) return new_obj;

std::vector<expr_t> new_args = std::move(other_args);
new_args.insert(new_args.end(), x64_args.begin(), x64_args.end());

return nary_op_t::make(nary_op->op_kind, new_args);
return reorder_nary_add_args(new_obj, /*x64_first=*/false);
}
};

Expand Down

0 comments on commit 31ac0e0

Please # to comment.