Skip to content

Commit

Permalink
Some more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent c57aa33 commit ac83c66
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.util.function.BiFunction;

public class ConstantFoldingFunctions {
static final double EPS = 1e-10;

public static BiFunction<Number, RewriterStatement, Number> foldingBiFunction(String op, String type) {
switch (op) {
case "+":
Expand Down Expand Up @@ -36,6 +38,15 @@ public static boolean isNeutralElement(Object num, String op) {
return false;
}

// TODO: What about NaNs?
public static RewriterStatement overwritesLiteral(Number num, String op, final RuleContext ctx) {
if (op.equals("*") && Math.abs(num.doubleValue()) < EPS) {
return RewriterStatement.literal(ctx, num);
}

return null;
}

public static double foldSumFloat(double num, RewriterStatement next) {
return num + next.floatLiteral();
}
Expand Down
12 changes: 10 additions & 2 deletions src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,6 @@ public static Function<RewriterStatement, RewriterStatement> buildCanonicalFormC
}, debug);

RewriterUtils.mergeArgLists(stmt, ctx);
stmt = foldConstants(stmt, ctx);
stmt = afterFlattening.apply(stmt, (t, r) -> {
if (!debug)
return true;
Expand Down Expand Up @@ -1352,6 +1351,10 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final
if (argList.size() == 1)
return argList.get(0);
}

RewriterStatement overwrite = ConstantFoldingFunctions.overwritesLiteral((Number)argList.get(literals[0]).getLiteral(), stmt.trueInstruction(), ctx);
if (overwrite != null)
return overwrite;
}

if (literals.length < 2)
Expand All @@ -1367,11 +1370,16 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final
for (int literal : literals)
val = foldingFunction.apply(val, argList.get(literal));


RewriterStatement overwrite = ConstantFoldingFunctions.overwritesLiteral(val, stmt.trueInstruction(), ctx);
if (overwrite != null)
return overwrite;

foldedLiteral.as(val.toString()).ofType(rType).asLiteral(val).consolidate(ctx);

argList.removeIf(RewriterStatement::isLiteral);

if (!ConstantFoldingFunctions.isNeutralElement(foldedLiteral.getLiteral(), stmt.trueInstruction()))
if (argList.isEmpty() || !ConstantFoldingFunctions.isNeutralElement(foldedLiteral.getLiteral(), stmt.trueInstruction()))
argList.add(foldedLiteral);

if (argList.size() == 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,4 +706,20 @@ public void testConstantFolding3() {

assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}

@Test
public void testConstantFolding4() {
RewriterStatement stmt1 = RewriterUtils.parse("*(A, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0");
RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), 0, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0");

stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));

assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}
}

0 comments on commit ac83c66

Please # to comment.