From 91071b8d2b4257aa62ecc4d460c2945197bcb05e Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Tue, 7 Nov 2023 14:10:48 +0000 Subject: [PATCH] use a more compact SMT encoding for dynamic FP rounding --- ir/instr.cpp | 101 +++++++++++++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/ir/instr.cpp b/ir/instr.cpp index 990648d19..27faef2a6 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -672,27 +672,27 @@ static expr handle_subnormal(const State &s, FPDenormalAttrs::Type attr, template static T round_value(const State &s, FpRoundingMode rm, AndExpr &non_poison, - const function &fn) { + const function &fn) { if (rm.isDefault()) - return fn(FpRoundingMode::RNE); + return fn(expr::rne()); auto &var = s.getFpRoundingMode(); if (!rm.isDynamic()) { non_poison.add(var == rm.getMode()); - return fn(rm); + return fn(rm.toSMT()); } - return T::mkIf(var == FpRoundingMode::RNE, fn(FpRoundingMode::RNE), - T::mkIf(var == FpRoundingMode::RNA, fn(FpRoundingMode::RNA), - T::mkIf(var == FpRoundingMode::RTP, fn(FpRoundingMode::RTP), - T::mkIf(var == FpRoundingMode::RTN, fn(FpRoundingMode::RTN), - fn(FpRoundingMode::RTZ))))); + return fn(expr::mkIf(var == FpRoundingMode::RNE, expr::rne(), + expr::mkIf(var == FpRoundingMode::RNA, expr::rna(), + expr::mkIf(var == FpRoundingMode::RTP, expr::rtp(), + expr::mkIf(var == FpRoundingMode::RTN, expr::rtn(), + expr::rtz()))))); } static StateValue fm_poison(State &s, expr a, const expr &ap, expr b, const expr &bp, expr c, const expr &cp, function fn, + const expr&, const expr&)> fn, const Type &ty, FastMathFlags fmath, FpRoundingMode rm, bool bitwise, bool flags_in_only = false, int nary = 3) { @@ -728,7 +728,7 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, expr b, fp_c = handle_subnormal(s, fpdenormal, std::move(fp_c)); } - function fn_rm + function fn_rm = [&](auto rm) { return fn(fp_a, fp_b, fp_c, rm); }; expr val = bitwise ? fn(a, b, c, {}) : round_value(s, rm, non_poison, fn_rm); @@ -781,7 +781,7 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, expr b, static StateValue fm_poison(State &s, expr a, const expr &ap, expr b, const expr &bp, function fn, + const expr&)> fn, const Type &ty, FastMathFlags fmath, FpRoundingMode rm, bool bitwise, bool flags_in_only = false) { @@ -792,7 +792,7 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, expr b, } static StateValue fm_poison(State &s, expr a, const expr &ap, - function fn, + function fn, const Type &ty, FastMathFlags fmath, FpRoundingMode rm, bool bitwise, bool flags_in_only = false) { @@ -802,36 +802,36 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, } StateValue FpBinOp::toSMT(State &s) const { - function fn; + function fn; bool bitwise = false; switch (op) { case FAdd: - fn = [](const expr &a, const expr &b, FpRoundingMode rm) { - return a.fadd(b, rm.toSMT()); + fn = [](const expr &a, const expr &b, const expr &rm) { + return a.fadd(b, rm); }; break; case FSub: - fn = [](const expr &a, const expr &b, FpRoundingMode rm) { - return a.fsub(b, rm.toSMT()); + fn = [](const expr &a, const expr &b, const expr &rm) { + return a.fsub(b, rm); }; break; case FMul: - fn = [](const expr &a, const expr &b, FpRoundingMode rm) { - return a.fmul(b, rm.toSMT()); + fn = [](const expr &a, const expr &b, const expr &rm) { + return a.fmul(b, rm); }; break; case FDiv: - fn = [](const expr &a, const expr &b, FpRoundingMode rm) { - return a.fdiv(b, rm.toSMT()); + fn = [](const expr &a, const expr &b, const expr &rm) { + return a.fdiv(b, rm); }; break; case FRem: - fn = [&](const expr &a, const expr &b, FpRoundingMode rm) { + fn = [&](const expr &a, const expr &b, const expr &rm) { // TODO; Z3 has no support for LLVM's frem which is actually an fmod auto val = a.frem(b); s.doesApproximation("frem", val); @@ -841,7 +841,7 @@ StateValue FpBinOp::toSMT(State &s) const { case FMin: case FMax: - fn = [&](const expr &a, const expr &b, FpRoundingMode rm) { + fn = [&](const expr &a, const expr &b, const expr &rm) { expr ndet = s.getFreshNondetVar("maxminnondet", true); auto ndz = expr::mkIf(ndet, expr::mkNumber("0", a), expr::mkNumber("-0", a)); @@ -857,7 +857,7 @@ StateValue FpBinOp::toSMT(State &s) const { case FMinimum: case FMaximum: - fn = [&](const expr &a, const expr &b, FpRoundingMode rm) { + fn = [&](const expr &a, const expr &b, const expr &rm) { expr zpos = expr::mkNumber("0", a), zneg = expr::mkNumber("-0", a); expr cmp = (op == FMinimum) ? a.fole(b) : a.foge(b); expr neg_cond = op == FMinimum ? (a.isFPNegative() || b.isFPNegative()) @@ -872,7 +872,7 @@ StateValue FpBinOp::toSMT(State &s) const { case CopySign: bitwise = true; - fn = [](const expr &a, const expr &b, FpRoundingMode rm) { + fn = [](const expr &a, const expr &b, const expr &rm) { return a.copysign(b); }; break; @@ -1086,43 +1086,43 @@ void FpUnaryOp::print(ostream &os) const { } StateValue FpUnaryOp::toSMT(State &s) const { - expr (*fn)(const expr&, FpRoundingMode) = nullptr; + expr (*fn)(const expr&, const expr&) = nullptr; bool bitwise = false; switch (op) { case FAbs: bitwise = true; - fn = [](const expr &v, FpRoundingMode rm) { return v.fabs(); }; + fn = [](const expr &v, const expr &rm) { return v.fabs(); }; break; case FNeg: bitwise = true; - fn = [](const expr &v, FpRoundingMode rm) { return v.fneg(); }; + fn = [](const expr &v, const expr &rm) { return v.fneg(); }; break; case Canonicalize: - fn = [](const expr &v, FpRoundingMode rm) { return v; }; + fn = [](const expr &v, const expr &rm) { return v; }; break; case Ceil: - fn = [](const expr &v, FpRoundingMode rm) { return v.ceil(); }; + fn = [](const expr &v, const expr &rm) { return v.ceil(); }; break; case Floor: - fn = [](const expr &v, FpRoundingMode rm) { return v.floor(); }; + fn = [](const expr &v, const expr &rm) { return v.floor(); }; break; case RInt: case NearbyInt: // TODO: they differ in exception behavior - fn = [](const expr &v, FpRoundingMode rm) { return v.round(rm.toSMT()); }; + fn = [](const expr &v, const expr &rm) { return v.round(rm); }; break; case Round: - fn = [](const expr &v, FpRoundingMode rm) { return v.round(expr::rna()); }; + fn = [](const expr &v, const expr &rm) { return v.round(expr::rna()); }; break; case RoundEven: - fn = [](const expr &v, FpRoundingMode rm) { return v.round(expr::rne()); }; + fn = [](const expr &v, const expr &rm) { return v.round(expr::rne()); }; break; case Trunc: - fn = [](const expr &v, FpRoundingMode rm) { return v.round(expr::rtz()); }; + fn = [](const expr &v, const expr &rm) { return v.round(expr::rtz()); }; break; case Sqrt: - fn = [](const expr &v, FpRoundingMode rm) { return v.sqrt(rm.toSMT()); }; + fn = [](const expr &v, const expr &rm) { return v.sqrt(rm); }; break; } @@ -1383,17 +1383,16 @@ void FpTernaryOp::print(ostream &os) const { } StateValue FpTernaryOp::toSMT(State &s) const { - function fn; + function fn; switch (op) { case FMA: - fn = [](const expr &a, const expr &b, const expr &c, FpRoundingMode rm) { - return expr::fma(a, b, c, rm.toSMT()); + fn = [](const expr &a, const expr &b, const expr &c, const expr &rm) { + return expr::fma(a, b, c, rm); }; break; case MulAdd: - fn = [&](const expr &a, const expr &b, const expr &c, FpRoundingMode rm0) { - auto rm = rm0.toSMT(); + fn = [&](const expr &a, const expr &b, const expr &c, const expr &rm) { expr var = s.getFreshNondetVar("nondet", expr(false)); return expr::mkIf(var, expr::fma(a, b, c, rm), a.fmul(b, rm).fadd(c, rm)); }; @@ -1682,19 +1681,19 @@ void FpConversionOp::print(ostream &os) const { StateValue FpConversionOp::toSMT(State &s) const { auto &v = s[*val]; - function fn; + function fn; switch (op) { case SIntToFP: fn = [](auto &val, auto &to_type, auto rm) -> StateValue { - return { val.sint2fp(to_type.getAsFloatType()->getDummyFloat(), - rm.toSMT()), true }; + return { val.sint2fp(to_type.getAsFloatType()->getDummyFloat(), rm), + true }; }; break; case UIntToFP: fn = [](auto &val, auto &to_type, auto rm) -> StateValue { - return { val.uint2fp(to_type.getAsFloatType()->getDummyFloat(), - rm.toSMT()), true }; + return { val.uint2fp(to_type.getAsFloatType()->getDummyFloat(), rm), + true }; }; break; case FPToSInt: @@ -1709,7 +1708,7 @@ StateValue FpConversionOp::toSMT(State &s) const { is_poison = true; break; case LRInt: - rm = rm_in.toSMT(); + rm = rm_in; break; case LRound: rm = expr::rna(); @@ -1746,8 +1745,8 @@ StateValue FpConversionOp::toSMT(State &s) const { case FPExt: case FPTrunc: fn = [](auto &val, auto &to_type, auto rm) -> StateValue { - return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(), - rm.toSMT()), true }; + return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(), rm), + true }; }; break; } @@ -1761,13 +1760,13 @@ StateValue FpConversionOp::toSMT(State &s) const { val = ty->getFloat(val); } - function fn_rm + function fn_rm = [&](auto rm) { return fn(val, to_type, rm); }; AndExpr np; np.add(sv.non_poison); StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm) - : fn(val, to_type, rm); + : fn(val, to_type, rm.toSMT()); np.add(std::move(ret.non_poison)); return { to_type.isFloatType()