diff --git a/src/gpu/jit/ir/ir.hpp b/src/gpu/jit/ir/ir.hpp index a6a3126e831..fe7c7233086 100644 --- a/src/gpu/jit/ir/ir.hpp +++ b/src/gpu/jit/ir/ir.hpp @@ -410,6 +410,7 @@ expr_t nary_op_back_transform(const expr_t &e); expr_t nary_op_canonicalize(const expr_t &_e); expr_t make_nary_op(op_kind_t op_kind, const std::vector &args); std::vector cvt_expr_to_nary_op_args(const expr_t &e); +expr_t reorder_nary_add_args(const expr_t &e, bool x64_first); // Substitutes all occurrences of `from` to `to` in `root`. object_t substitute(const object_t &root, const object_t &from, diff --git a/src/gpu/jit/pass/overflow.cpp b/src/gpu/jit/pass/overflow.cpp index 22fbdad86d6..b3b149cfa72 100644 --- a/src/gpu/jit/pass/overflow.cpp +++ b/src/gpu/jit/pass/overflow.cpp @@ -70,6 +70,100 @@ class overflow_bound_finder_t : public bound_finder_base_t { object_map_t> var_bounds_; }; +struct overflow_context_t { + overflow_bound_finder_t bound_finder; + object_map_t> vec_vars; + object_set_t vars_with_load; + + bool contains_load(const expr_t &e) const { + if (!find_objects(e).empty()) return true; + for (auto &v : find_objects(e)) { + if (vars_with_load.count(v) != 0) return true; + } + return false; + } +}; + +class expr_overflow_fixer_t : public ir_mutator_t { +public: + expr_overflow_fixer_t(const overflow_context_t &ctx) : ctx_(ctx) {} + + object_t _mutate(const binary_op_t &obj) override { + return mutate_expr(obj); + } + + object_t _mutate(const unary_op_t &obj) override { + return mutate_expr(obj); + } + +private: + template + object_t mutate_expr(const T &obj) { + expr_t new_obj = ir_mutator_t::_mutate(obj); + if (!new_obj.type().is_x32()) return std::move(new_obj); + if (ctx_.contains_load(new_obj)) return std::move(new_obj); + + bool found_overflow = false; + int elems = new_obj.type().elems(); + for (int i = 0; i < elems; i++) { + expr_scalarizer_t scalarizer(elems, i, ctx_.vec_vars); + expr_t value = scalarizer.mutate(new_obj); + int64_t lo = ctx_.bound_finder.find_low_bound(value); + int64_t hi = ctx_.bound_finder.find_high_bound(value); + bool ok = bound_finder_base_t::is_good_bound(lo) + && bound_finder_base_t::is_good_bound(hi); + if (ok) { + int64_t type_lo = value.type().is_s32() + ? (int64_t)std::numeric_limits::min() + : (int64_t)std::numeric_limits::min(); + int64_t type_hi = value.type().is_s32() + ? (int64_t)std::numeric_limits::max() + : (int64_t)std::numeric_limits::max(); + + bool is_overflow = (lo < type_lo || hi > type_hi); + if (is_overflow) { + found_overflow = true; + ir_warning() << "Found overflow: " << value + << " low bound: " << lo + << " high bound: " << hi << std::endl; + break; + } + } + } + if (found_overflow) return fix_overflow(new_obj); + return std::move(new_obj); + } + + static expr_t fix_overflow(const expr_t &e) { + auto *binary = e.as_ptr(); + if (binary) { + return binary_op_t::make(binary->op_kind, + cast(binary->a, type_t::u64(e.type().elems())), binary->b); + } + + ir_error_not_expected() << "Can't fix overflow: " << e; + return e; + } + + const overflow_context_t &ctx_; +}; + +expr_t fix_expr_overflow(const expr_t &e, const overflow_context_t &ctx) { + auto e_fixed = expr_overflow_fixer_t(ctx).mutate(e); + if (e_fixed.is_same(e)) return e; + + // Overflow detected, try to rearrange summands and avoid explicit casting. + auto nary = reorder_nary_add_args( + nary_op_canonicalize(e), /*x64_first=*/true); + auto e_reordered = nary_op_back_transform(nary); + auto e_reordered_fixed = expr_overflow_fixer_t(ctx).mutate(e_reordered); + if (e_reordered_fixed.is_same(e_reordered)) { + // No overflow detected after rearranging, return it. + return e_reordered; + } + return e_fixed; +} + class overflow_fixer_t : public ir_mutator_t { public: overflow_fixer_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) { @@ -90,7 +184,7 @@ class overflow_fixer_t : public ir_mutator_t { << to_string(rel.op_kind()); } } - bound_finder_.set_var_bounds(kv.first, {lo, hi}); + ctx_.bound_finder.set_var_bounds(kv.first, {lo, hi}); } } @@ -99,13 +193,13 @@ class overflow_fixer_t : public ir_mutator_t { } object_t _mutate(const binary_op_t &obj) override { - return mutate_expr(obj); + return fix_expr_overflow(obj, ctx_); } object_t _mutate(const for_t &obj) override { auto lo = to_cpp(obj.init); auto hi = to_cpp(obj.bound) - 1; - bound_finder_.set_var_bounds(obj.var, {lo, hi}); + ctx_.bound_finder.set_var_bounds(obj.var, {lo, hi}); return ir_mutator_t::_mutate(obj); } @@ -114,25 +208,25 @@ class overflow_fixer_t : public ir_mutator_t { if (!obj.var.type().is_int()) ok = false; if (ok && obj.value.is_empty()) ok = false; if (ok && obj.value.type().is_bool()) ok = false; - if (ok && bound_finder_.has_var(obj.var)) ok = false; + if (ok && ctx_.bound_finder.has_var(obj.var)) ok = false; if (ok) { - if (contains_load(obj.value)) { - vars_with_load_.insert(obj.var); + if (ctx_.contains_load(obj.value)) { + ctx_.vars_with_load.insert(obj.var); ok = false; } } if (ok) { int elems = obj.var.type().elems(); - vec_vars_[obj.var].reserve(elems); + ctx_.vec_vars[obj.var].reserve(elems); for (int i = 0; i < elems; i++) { auto var_i = make_vec_var(obj.var, elems, i); - expr_scalarizer_t scalarizer(elems, i, vec_vars_); + expr_scalarizer_t scalarizer(elems, i, ctx_.vec_vars); auto value_i = scalarizer.mutate(obj.value); - auto lo_hi = bound_finder_.find_bounds(value_i); - bound_finder_.set_var_bounds(var_i, lo_hi); - vec_vars_[obj.var].push_back(var_i); + auto lo_hi = ctx_.bound_finder.find_bounds(value_i); + ctx_.bound_finder.set_var_bounds(var_i, lo_hi); + ctx_.vec_vars[obj.var].push_back(var_i); } } expr_t var = obj.var; @@ -150,55 +244,10 @@ class overflow_fixer_t : public ir_mutator_t { } object_t _mutate(const unary_op_t &obj) override { - return mutate_expr(obj); + return fix_expr_overflow(obj, ctx_); } private: - template - object_t mutate_expr(const T &obj) { - expr_t new_obj = ir_mutator_t::_mutate(obj); - if (!new_obj.type().is_x32()) return std::move(new_obj); - if (contains_load(new_obj)) return std::move(new_obj); - - bool found_overflow = false; - int elems = new_obj.type().elems(); - for (int i = 0; i < elems; i++) { - expr_scalarizer_t scalarizer(elems, i, vec_vars_); - expr_t value = scalarizer.mutate(new_obj); - int64_t lo = bound_finder_.find_low_bound(value); - int64_t hi = bound_finder_.find_high_bound(value); - bool ok = bound_finder_base_t::is_good_bound(lo) - && bound_finder_base_t::is_good_bound(hi); - if (ok) { - int64_t type_lo = value.type().is_s32() - ? (int64_t)std::numeric_limits::min() - : (int64_t)std::numeric_limits::min(); - int64_t type_hi = value.type().is_s32() - ? (int64_t)std::numeric_limits::max() - : (int64_t)std::numeric_limits::max(); - - bool is_overflow = (lo < type_lo || hi > type_hi); - if (is_overflow) { - found_overflow = true; - ir_warning() << "Found overflow: " << value - << " low bound: " << lo - << " high bound: " << hi << std::endl; - break; - } - } - } - if (found_overflow) return fix_overflow(new_obj); - return std::move(new_obj); - } - - bool contains_load(const expr_t &e) const { - if (!find_objects(e).empty()) return true; - for (auto &v : find_objects(e)) { - if (vars_with_load_.count(v) != 0) return true; - } - return false; - } - static expr_t make_vec_var(const expr_t &_var, int elems, int idx) { if (elems == 1) return _var; auto &var = _var.as(); @@ -206,21 +255,8 @@ class overflow_fixer_t : public ir_mutator_t { return var_t::make(var.type.scalar(), vec_name); } - static expr_t fix_overflow(const expr_t &e) { - auto *binary = e.as_ptr(); - if (binary) { - return binary_op_t::make(binary->op_kind, - cast(binary->a, type_t::u64(e.type().elems())), binary->b); - } - - ir_error_not_expected() << "Can't fix overflow: " << e; - return e; - } - ir_context_t &ir_ctx_; - overflow_bound_finder_t bound_finder_; - object_map_t> vec_vars_; - object_set_t vars_with_load_; + overflow_context_t ctx_; }; stmt_t fix_int32_overflow(const stmt_t &s, ir_context_t &ir_ctx) { diff --git a/src/gpu/jit/pass/simplify.cpp b/src/gpu/jit/pass/simplify.cpp index 9c3557ea963..90e64f8a378 100644 --- a/src/gpu/jit/pass/simplify.cpp +++ b/src/gpu/jit/pass/simplify.cpp @@ -1497,6 +1497,30 @@ class common_factor_simplifier_t : public nary_op_mutator_t { } }; +expr_t reorder_nary_add_args(const expr_t &e, bool x64_first) { + auto *nary_op = e.as_ptr(); + if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2) + return e; + + std::vector other_args; + std::vector x64_args; + for (auto &a : nary_op->args) { + if (a.type().is_x64()) { + x64_args.push_back(a); + } else { + other_args.push_back(a); + } + } + + if (other_args.empty() || x64_args.empty()) return e; + + std::vector new_args = std::move(other_args); + new_args.insert(x64_first ? new_args.begin() : new_args.end(), + x64_args.begin(), x64_args.end()); + + return nary_op_t::make(nary_op->op_kind, new_args); +} + // Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit // arithmetic. Example: // Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add] @@ -1505,26 +1529,7 @@ class _64_bit_add_optimizer_t : public nary_op_mutator_t { public: object_t _mutate(const nary_op_t &obj) override { auto new_obj = nary_op_mutator_t::_mutate(obj); - auto *nary_op = new_obj.as_ptr(); - if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2) - return new_obj; - - std::vector other_args; - std::vector x64_args; - for (auto &a : nary_op->args) { - if (a.type().is_x64()) { - x64_args.push_back(a); - } else { - other_args.push_back(a); - } - } - - if (other_args.empty() || x64_args.empty()) return new_obj; - - std::vector new_args = std::move(other_args); - new_args.insert(new_args.end(), x64_args.begin(), x64_args.end()); - - return nary_op_t::make(nary_op->op_kind, new_args); + return reorder_nary_add_args(new_obj, /*x64_first=*/false); } };