diff --git a/src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java b/src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java index e8f4932eeda..42f9e31ca15 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java @@ -9,18 +9,25 @@ public class DMLExecutor { private static PrintStream origPrintStream = System.out; - public static synchronized void executeCode(String code) { - executeCode(code, s -> {}); + public static synchronized void executeCode(String code, boolean intercept, String... additionalArgs) { + executeCode(code, intercept ? s -> {} : null, additionalArgs); } // TODO: We will probably need some kind of watchdog // This cannot run in parallel - public static synchronized void executeCode(String code, Consumer consoleInterceptor) { + public static synchronized void executeCode(String code, Consumer consoleInterceptor, String... additionalArgs) { try { if (consoleInterceptor != null) System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor))); - DMLScript.executeScript(new String[]{"-s", code}); + String[] args = new String[additionalArgs.length + 2]; + + for (int i = 0; i < additionalArgs.length; i++) + args[i] = additionalArgs[i]; + + args[additionalArgs.length] = "-s"; + args[additionalArgs.length + 1] = code; + DMLScript.executeScript(args); } catch (Exception e) { e.printStackTrace(); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java index 11264c44207..444b57b43bf 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java @@ -2,6 +2,7 @@ import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.rewrite.HopRewriteRule; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.hops.rewrite.ProgramRewriteStatus; import java.io.IOException; @@ -74,6 +75,9 @@ private void rule_apply(Hop hop, boolean descendFirst) if(hop.isVisited()) return; + //DMLExecutor.println("Hop: " + hop + ", " + hop.getName() + ": " + HopRewriteUtils.isSparse(hop)); + //DMLExecutor.println("NNZ: " + hop.getNnz()); + //System.out.println("Stepping into: " + hop); //recursively process children diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java index ee7a5f1eed5..43904e8cf3e 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java @@ -21,6 +21,7 @@ public static long estimateCost(RewriterStatement stmt, final RuleContext ctx) { public static long estimateCost(RewriterStatement stmt, Function propertyGenerator, final RuleContext ctx) { RewriterAssertions assertions = new RewriterAssertions(ctx); RewriterStatement costFn = propagateCostFunction(stmt, ctx, assertions); + System.out.println(costFn.toParsableString(ctx)); Map map = new HashMap<>(); @@ -42,6 +43,7 @@ public static long estimateCost(RewriterStatement stmt, Function rules, fina for (int idx = 0; idx < 2; idx++) { RewriterStatement oldRef = lnk.oldStmt.getOperands().get(idx); RewriterStatement newRef = lnk.newStmt.get(0).getChild(idx); - RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, newRef.getChild(1, 1, 0)).consolidate(ctx); + System.out.println("NewRef: " + newRef.toParsableString(ctx)); + RewriterStatement mStmtC = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef.getChild(1, 1, 0), RewriterStatement.literal(ctx, -1L)).consolidate(ctx); + RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, mStmtC).consolidate(ctx); + System.out.println("mStmt: " + mStmt.toParsableString(ctx)); final RewriterStatement newStmt = RewriterUtils.foldConstants(mStmt, ctx); // Replace all references to h with @@ -1049,8 +1052,10 @@ public static void pushdownStreamSelections(final List rules, fina RewriterStatement child = el.getOperands().get(i); Object meta = child.getMeta("idxId"); - if (meta instanceof UUID && meta.equals(oldRef.getMeta("idxId"))) + if (meta instanceof UUID && meta.equals(oldRef.getMeta("idxId"))) { + System.out.println("NewStmt: " + newStmt.toParsableString(ctx)); el.getOperands().set(i, newStmt); + } } }, false); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java index 4941525be55..12c58eb9f47 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java @@ -216,7 +216,7 @@ public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final R return false; // The program should not be executed as we just want to extract any rewrites that are applied to the current statement }); - DMLExecutor.executeCode(code2); + DMLExecutor.executeCode(code2, true); RewriterRuntimeUtils.detachHopInterceptor(); return isValid.booleanValue() && isRelevant.booleanValue(); diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeTest.java new file mode 100644 index 00000000000..1f70decac48 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeTest.java @@ -0,0 +1,17 @@ +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.rewriter.DMLExecutor; +import org.junit.Test; + +public class CodeTest { + @Test + public void test() { + String str = "X = rand(rows=5000, cols=5000, sparsity=0.1)\n" + + "Y = rand(rows=5000, cols=5000, sparsity=0.1)\n" + + "R = X*Y\n" + + "print(lineage(R))"; + DMLScript.APPLY_GENERATED_REWRITES = true; + DMLExecutor.executeCode(str, false, "-applyGeneratedRewrites"); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java index 1ce3836541e..81b6004df84 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java @@ -16,7 +16,7 @@ public class CostEstimates { @BeforeClass public static void setup() { ctx = RewriterUtils.buildDefaultContext(); - canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true); } @Test @@ -217,4 +217,55 @@ public void test11() { System.out.println(stmt2.toParsableString(ctx, true)); assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2)); } + + @Test + public void test12() { + String stmtStr1 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "+([](A, 1, nrow(A), 1, 1),B)"; + String stmtStr2 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "+([](A, 1, nrow(A), 1, ncol(A)), B)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + System.out.println("Cost1: " + cost1); + System.out.println("Cost2: " + cost2); + System.out.println("Ratio: " + ((double)cost1)/cost2); + + assert cost1 < cost2; + } + + @Test + public void test13() { + String stmtStr1 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "[](rowSums(A), 1, nrow(A), 1, 1)"; + String stmtStr2 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "rowSums(A)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + System.out.println("Cost1: " + cost1); + System.out.println("Cost2: " + cost2); + System.out.println("Ratio: " + ((double)cost1)/cost2); + + assert cost2 < cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + System.out.println("=========="); + System.out.println(stmt1.toParsableString(ctx, true)); + System.out.println("=========="); + System.out.println(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2)); + } }