diff --git a/conf/SystemDS-config-defaults.xml b/conf/SystemDS-config-defaults.xml index f48912c1f4d..d285c55c5f1 100644 --- a/conf/SystemDS-config-defaults.xml +++ b/conf/SystemDS-config-defaults.xml @@ -18,4 +18,5 @@ --> + 2 diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index 32ce3580868..c4d774b975d 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -86,6 +86,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse + if ( DMLScript.APPLY_GENERATED_REWRITES ) { + _dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass())); + } if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java index 32cac368e93..9955119dbe2 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java @@ -17,6 +17,7 @@ public class RewriteAutomaticallyGenerated extends HopRewriteRule { public static final String FILE_PATH = "/Users/janniklindemann/Dev/MScThesis/rules.rl"; public static final String VALIDATED_FILE_PATH = "/Users/janniklindemann/Dev/MScThesis/rules_validated.rl"; public static final String RAW_FILE_PATH = "/Users/janniklindemann/Dev/MScThesis/raw_rules.rl"; + public static final String FILE_PATH_MB = "/Users/janniklindemann/Dev/MScThesis/rules_mb.rl"; public static RewriteAutomaticallyGenerated existingRewrites; private Function rewriteFn; diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java index b8f41ffa22e..7d767cf4197 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java @@ -80,6 +80,8 @@ public boolean determineConditionalApplicability() { if (!requireCostCheck) return false; + List roots = toRoots == null ? List.of(toRoot) : toRoots; + boolean integrateSparsityInCosts = isConditionalMultiRule() || RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, false, 20); MutableObject assertionRef = new MutableObject<>(assertions); @@ -87,7 +89,7 @@ public boolean determineConditionalApplicability() { toCosts = getStmt2AsList().stream().map(root -> RewriterCostEstimator.getRawCostFunction(root, ctx, assertionRef, !integrateSparsityInCosts)).collect(Collectors.toList()); fromCost = RewriterSparsityEstimator.rollupSparsities(fromCost, RewriterSparsityEstimator.estimateAllNNZ(fromRoot, ctx), ctx); - toCosts = IntStream.range(0, toCosts.size()).mapToObj(i -> RewriterSparsityEstimator.rollupSparsities(toCosts.get(i), RewriterSparsityEstimator.estimateAllNNZ(toRoots.get(i), ctx), ctx)).collect(Collectors.toList()); + toCosts = IntStream.range(0, toCosts.size()).mapToObj(i -> RewriterSparsityEstimator.rollupSparsities(toCosts.get(i), RewriterSparsityEstimator.estimateAllNNZ(roots.get(i), ctx), ctx)).collect(Collectors.toList()); return requireCostCheck; } diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java index 30a15131d19..6673ba2723f 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java @@ -83,7 +83,6 @@ public static void testExpressionClustering() { int pruneDataSubexrBiggerThan = 1000; int maxCostSamples = 50; - long startTime = System.currentTimeMillis(); AtomicLong generatedExpressions = new AtomicLong(0); AtomicLong evaluatedExpressions = new AtomicLong(0); AtomicLong failures = new AtomicLong(0); diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java index d2c914c9a3e..a8f1cd25b9b 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java @@ -254,7 +254,7 @@ public void testConditional() { @Test public void codeGen() { try { - List lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH_CONDITIONAL)); + List lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH_MB)); RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); RewriterRuntimeUtils.printUnknowns = false;