Skip to content

Commit

Permalink
use a more compact SMT encoding for dynamic FP rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
nunoplopes committed Nov 7, 2023
1 parent 43c5183 commit 91071b8
Showing 1 changed file with 50 additions and 51 deletions.
101 changes: 50 additions & 51 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,27 +672,27 @@ static expr handle_subnormal(const State &s, FPDenormalAttrs::Type attr,

template <typename T>
static T round_value(const State &s, FpRoundingMode rm, AndExpr &non_poison,
const function<T(FpRoundingMode)> &fn) {
const function<T(const expr&)> &fn) {
if (rm.isDefault())
return fn(FpRoundingMode::RNE);
return fn(expr::rne());

auto &var = s.getFpRoundingMode();
if (!rm.isDynamic()) {
non_poison.add(var == rm.getMode());
return fn(rm);
return fn(rm.toSMT());
}

return T::mkIf(var == FpRoundingMode::RNE, fn(FpRoundingMode::RNE),
T::mkIf(var == FpRoundingMode::RNA, fn(FpRoundingMode::RNA),
T::mkIf(var == FpRoundingMode::RTP, fn(FpRoundingMode::RTP),
T::mkIf(var == FpRoundingMode::RTN, fn(FpRoundingMode::RTN),
fn(FpRoundingMode::RTZ)))));
return fn(expr::mkIf(var == FpRoundingMode::RNE, expr::rne(),
expr::mkIf(var == FpRoundingMode::RNA, expr::rna(),
expr::mkIf(var == FpRoundingMode::RTP, expr::rtp(),
expr::mkIf(var == FpRoundingMode::RTN, expr::rtn(),
expr::rtz())))));
}

static StateValue fm_poison(State &s, expr a, const expr &ap, expr b,
const expr &bp, expr c, const expr &cp,
function<expr(const expr&, const expr&,
const expr&, FpRoundingMode)> fn,
const expr&, const expr&)> fn,
const Type &ty, FastMathFlags fmath,
FpRoundingMode rm, bool bitwise,
bool flags_in_only = false, int nary = 3) {
Expand Down Expand Up @@ -728,7 +728,7 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, expr b,
fp_c = handle_subnormal(s, fpdenormal, std::move(fp_c));
}

function<expr(FpRoundingMode)> fn_rm
function<expr(const expr&)> fn_rm
= [&](auto rm) { return fn(fp_a, fp_b, fp_c, rm); };
expr val = bitwise ? fn(a, b, c, {}) : round_value(s, rm, non_poison, fn_rm);

Expand Down Expand Up @@ -781,7 +781,7 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, expr b,
static StateValue fm_poison(State &s, expr a, const expr &ap, expr b,
const expr &bp,
function<expr(const expr&, const expr&,
FpRoundingMode)> fn,
const expr&)> fn,
const Type &ty, FastMathFlags fmath,
FpRoundingMode rm, bool bitwise,
bool flags_in_only = false) {
Expand All @@ -792,7 +792,7 @@ static StateValue fm_poison(State &s, expr a, const expr &ap, expr b,
}

static StateValue fm_poison(State &s, expr a, const expr &ap,
function<expr(const expr&, FpRoundingMode)> fn,
function<expr(const expr&, const expr&)> fn,
const Type &ty, FastMathFlags fmath,
FpRoundingMode rm, bool bitwise,
bool flags_in_only = false) {
Expand All @@ -802,36 +802,36 @@ static StateValue fm_poison(State &s, expr a, const expr &ap,
}

StateValue FpBinOp::toSMT(State &s) const {
function<expr(const expr&, const expr&, FpRoundingMode)> fn;
function<expr(const expr&, const expr&, const expr&)> fn;
bool bitwise = false;

switch (op) {
case FAdd:
fn = [](const expr &a, const expr &b, FpRoundingMode rm) {
return a.fadd(b, rm.toSMT());
fn = [](const expr &a, const expr &b, const expr &rm) {
return a.fadd(b, rm);
};
break;

case FSub:
fn = [](const expr &a, const expr &b, FpRoundingMode rm) {
return a.fsub(b, rm.toSMT());
fn = [](const expr &a, const expr &b, const expr &rm) {
return a.fsub(b, rm);
};
break;

case FMul:
fn = [](const expr &a, const expr &b, FpRoundingMode rm) {
return a.fmul(b, rm.toSMT());
fn = [](const expr &a, const expr &b, const expr &rm) {
return a.fmul(b, rm);
};
break;

case FDiv:
fn = [](const expr &a, const expr &b, FpRoundingMode rm) {
return a.fdiv(b, rm.toSMT());
fn = [](const expr &a, const expr &b, const expr &rm) {
return a.fdiv(b, rm);
};
break;

case FRem:
fn = [&](const expr &a, const expr &b, FpRoundingMode rm) {
fn = [&](const expr &a, const expr &b, const expr &rm) {
// TODO; Z3 has no support for LLVM's frem which is actually an fmod
auto val = a.frem(b);
s.doesApproximation("frem", val);
Expand All @@ -841,7 +841,7 @@ StateValue FpBinOp::toSMT(State &s) const {

case FMin:
case FMax:
fn = [&](const expr &a, const expr &b, FpRoundingMode rm) {
fn = [&](const expr &a, const expr &b, const expr &rm) {
expr ndet = s.getFreshNondetVar("maxminnondet", true);
auto ndz = expr::mkIf(ndet, expr::mkNumber("0", a),
expr::mkNumber("-0", a));
Expand All @@ -857,7 +857,7 @@ StateValue FpBinOp::toSMT(State &s) const {

case FMinimum:
case FMaximum:
fn = [&](const expr &a, const expr &b, FpRoundingMode rm) {
fn = [&](const expr &a, const expr &b, const expr &rm) {
expr zpos = expr::mkNumber("0", a), zneg = expr::mkNumber("-0", a);
expr cmp = (op == FMinimum) ? a.fole(b) : a.foge(b);
expr neg_cond = op == FMinimum ? (a.isFPNegative() || b.isFPNegative())
Expand All @@ -872,7 +872,7 @@ StateValue FpBinOp::toSMT(State &s) const {

case CopySign:
bitwise = true;
fn = [](const expr &a, const expr &b, FpRoundingMode rm) {
fn = [](const expr &a, const expr &b, const expr &rm) {
return a.copysign(b);
};
break;
Expand Down Expand Up @@ -1086,43 +1086,43 @@ void FpUnaryOp::print(ostream &os) const {
}

StateValue FpUnaryOp::toSMT(State &s) const {
expr (*fn)(const expr&, FpRoundingMode) = nullptr;
expr (*fn)(const expr&, const expr&) = nullptr;
bool bitwise = false;

switch (op) {
case FAbs:
bitwise = true;
fn = [](const expr &v, FpRoundingMode rm) { return v.fabs(); };
fn = [](const expr &v, const expr &rm) { return v.fabs(); };
break;
case FNeg:
bitwise = true;
fn = [](const expr &v, FpRoundingMode rm) { return v.fneg(); };
fn = [](const expr &v, const expr &rm) { return v.fneg(); };
break;
case Canonicalize:
fn = [](const expr &v, FpRoundingMode rm) { return v; };
fn = [](const expr &v, const expr &rm) { return v; };
break;
case Ceil:
fn = [](const expr &v, FpRoundingMode rm) { return v.ceil(); };
fn = [](const expr &v, const expr &rm) { return v.ceil(); };
break;
case Floor:
fn = [](const expr &v, FpRoundingMode rm) { return v.floor(); };
fn = [](const expr &v, const expr &rm) { return v.floor(); };
break;
case RInt:
case NearbyInt:
// TODO: they differ in exception behavior
fn = [](const expr &v, FpRoundingMode rm) { return v.round(rm.toSMT()); };
fn = [](const expr &v, const expr &rm) { return v.round(rm); };
break;
case Round:
fn = [](const expr &v, FpRoundingMode rm) { return v.round(expr::rna()); };
fn = [](const expr &v, const expr &rm) { return v.round(expr::rna()); };
break;
case RoundEven:
fn = [](const expr &v, FpRoundingMode rm) { return v.round(expr::rne()); };
fn = [](const expr &v, const expr &rm) { return v.round(expr::rne()); };
break;
case Trunc:
fn = [](const expr &v, FpRoundingMode rm) { return v.round(expr::rtz()); };
fn = [](const expr &v, const expr &rm) { return v.round(expr::rtz()); };
break;
case Sqrt:
fn = [](const expr &v, FpRoundingMode rm) { return v.sqrt(rm.toSMT()); };
fn = [](const expr &v, const expr &rm) { return v.sqrt(rm); };
break;
}

Expand Down Expand Up @@ -1383,17 +1383,16 @@ void FpTernaryOp::print(ostream &os) const {
}

StateValue FpTernaryOp::toSMT(State &s) const {
function<expr(const expr&, const expr&, const expr&, FpRoundingMode)> fn;
function<expr(const expr&, const expr&, const expr&, const expr&)> fn;

switch (op) {
case FMA:
fn = [](const expr &a, const expr &b, const expr &c, FpRoundingMode rm) {
return expr::fma(a, b, c, rm.toSMT());
fn = [](const expr &a, const expr &b, const expr &c, const expr &rm) {
return expr::fma(a, b, c, rm);
};
break;
case MulAdd:
fn = [&](const expr &a, const expr &b, const expr &c, FpRoundingMode rm0) {
auto rm = rm0.toSMT();
fn = [&](const expr &a, const expr &b, const expr &c, const expr &rm) {
expr var = s.getFreshNondetVar("nondet", expr(false));
return expr::mkIf(var, expr::fma(a, b, c, rm), a.fmul(b, rm).fadd(c, rm));
};
Expand Down Expand Up @@ -1682,19 +1681,19 @@ void FpConversionOp::print(ostream &os) const {

StateValue FpConversionOp::toSMT(State &s) const {
auto &v = s[*val];
function<StateValue(const expr &, const Type &, FpRoundingMode)> fn;
function<StateValue(const expr &, const Type &, const expr&)> fn;

switch (op) {
case SIntToFP:
fn = [](auto &val, auto &to_type, auto rm) -> StateValue {
return { val.sint2fp(to_type.getAsFloatType()->getDummyFloat(),
rm.toSMT()), true };
return { val.sint2fp(to_type.getAsFloatType()->getDummyFloat(), rm),
true };
};
break;
case UIntToFP:
fn = [](auto &val, auto &to_type, auto rm) -> StateValue {
return { val.uint2fp(to_type.getAsFloatType()->getDummyFloat(),
rm.toSMT()), true };
return { val.uint2fp(to_type.getAsFloatType()->getDummyFloat(), rm),
true };
};
break;
case FPToSInt:
Expand All @@ -1709,7 +1708,7 @@ StateValue FpConversionOp::toSMT(State &s) const {
is_poison = true;
break;
case LRInt:
rm = rm_in.toSMT();
rm = rm_in;
break;
case LRound:
rm = expr::rna();
Expand Down Expand Up @@ -1746,8 +1745,8 @@ StateValue FpConversionOp::toSMT(State &s) const {
case FPExt:
case FPTrunc:
fn = [](auto &val, auto &to_type, auto rm) -> StateValue {
return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(),
rm.toSMT()), true };
return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(), rm),
true };
};
break;
}
Expand All @@ -1761,13 +1760,13 @@ StateValue FpConversionOp::toSMT(State &s) const {
val = ty->getFloat(val);
}

function<StateValue(FpRoundingMode)> fn_rm
function<StateValue(const expr&)> fn_rm
= [&](auto rm) { return fn(val, to_type, rm); };
AndExpr np;
np.add(sv.non_poison);

StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm)
: fn(val, to_type, rm);
: fn(val, to_type, rm.toSMT());
np.add(std::move(ret.non_poison));

return { to_type.isFloatType()
Expand Down

0 comments on commit 91071b8

Please # to comment.