diff --git a/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java b/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java index 1f606ee4677..36ed9e79d35 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java @@ -4,6 +4,8 @@ import java.util.function.BiFunction; public class ConstantFoldingFunctions { + static final double EPS = 1e-10; + public static BiFunction foldingBiFunction(String op, String type) { switch (op) { case "+": @@ -36,6 +38,15 @@ public static boolean isNeutralElement(Object num, String op) { return false; } + // TODO: What about NaNs? + public static RewriterStatement overwritesLiteral(Number num, String op, final RuleContext ctx) { + if (op.equals("*") && Math.abs(num.doubleValue()) < EPS) { + return RewriterStatement.literal(ctx, num); + } + + return null; + } + public static double foldSumFloat(double num, RewriterStatement next) { return num + next.floatLiteral(); } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java index af78afd1ecc..cccbbbed683 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java @@ -1248,7 +1248,6 @@ public static Function buildCanonicalFormC }, debug); RewriterUtils.mergeArgLists(stmt, ctx); - stmt = foldConstants(stmt, ctx); stmt = afterFlattening.apply(stmt, (t, r) -> { if (!debug) return true; @@ -1352,6 +1351,10 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final if (argList.size() == 1) return argList.get(0); } + + RewriterStatement overwrite = ConstantFoldingFunctions.overwritesLiteral((Number)argList.get(literals[0]).getLiteral(), stmt.trueInstruction(), ctx); + if (overwrite != null) + return overwrite; } if (literals.length < 2) @@ -1367,11 +1370,16 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final for (int literal : literals) val = foldingFunction.apply(val, argList.get(literal)); + + RewriterStatement overwrite = ConstantFoldingFunctions.overwritesLiteral(val, stmt.trueInstruction(), ctx); + if (overwrite != null) + return overwrite; + foldedLiteral.as(val.toString()).ofType(rType).asLiteral(val).consolidate(ctx); argList.removeIf(RewriterStatement::isLiteral); - if (!ConstantFoldingFunctions.isNeutralElement(foldedLiteral.getLiteral(), stmt.trueInstruction())) + if (argList.isEmpty() || !ConstantFoldingFunctions.isNeutralElement(foldedLiteral.getLiteral(), stmt.trueInstruction())) argList.add(foldedLiteral); if (argList.size() == 1) diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java index 384a4baa9e7..b56b5b5507e 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java @@ -706,4 +706,20 @@ public void testConstantFolding3() { assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2)); } + + @Test + public void testConstantFolding4() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), 0, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + System.out.println("=========="); + System.out.println(stmt1.toParsableString(ctx, true)); + System.out.println("=========="); + System.out.println(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2)); + } }