diff --git a/conf/SystemDS-config-defaults.xml b/conf/SystemDS-config-defaults.xml
index f48912c1f4d..d285c55c5f1 100644
--- a/conf/SystemDS-config-defaults.xml
+++ b/conf/SystemDS-config-defaults.xml
@@ -18,4 +18,5 @@
-->
+ 2
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 32ce3580868..c4d774b975d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -86,6 +86,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
+ if ( DMLScript.APPLY_GENERATED_REWRITES ) {
+ _dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass()));
+ }
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again)
diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java
index 32cac368e93..9955119dbe2 100644
--- a/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java
+++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriteAutomaticallyGenerated.java
@@ -17,6 +17,7 @@ public class RewriteAutomaticallyGenerated extends HopRewriteRule {
public static final String FILE_PATH = "/Users/janniklindemann/Dev/MScThesis/rules.rl";
public static final String VALIDATED_FILE_PATH = "/Users/janniklindemann/Dev/MScThesis/rules_validated.rl";
public static final String RAW_FILE_PATH = "/Users/janniklindemann/Dev/MScThesis/raw_rules.rl";
+ public static final String FILE_PATH_MB = "/Users/janniklindemann/Dev/MScThesis/rules_mb.rl";
public static RewriteAutomaticallyGenerated existingRewrites;
private Function rewriteFn;
diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java
index b8f41ffa22e..7d767cf4197 100644
--- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java
+++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java
@@ -80,6 +80,8 @@ public boolean determineConditionalApplicability() {
if (!requireCostCheck)
return false;
+ List roots = toRoots == null ? List.of(toRoot) : toRoots;
+
boolean integrateSparsityInCosts = isConditionalMultiRule() || RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, false, 20);
MutableObject assertionRef = new MutableObject<>(assertions);
@@ -87,7 +89,7 @@ public boolean determineConditionalApplicability() {
toCosts = getStmt2AsList().stream().map(root -> RewriterCostEstimator.getRawCostFunction(root, ctx, assertionRef, !integrateSparsityInCosts)).collect(Collectors.toList());
fromCost = RewriterSparsityEstimator.rollupSparsities(fromCost, RewriterSparsityEstimator.estimateAllNNZ(fromRoot, ctx), ctx);
- toCosts = IntStream.range(0, toCosts.size()).mapToObj(i -> RewriterSparsityEstimator.rollupSparsities(toCosts.get(i), RewriterSparsityEstimator.estimateAllNNZ(toRoots.get(i), ctx), ctx)).collect(Collectors.toList());
+ toCosts = IntStream.range(0, toCosts.size()).mapToObj(i -> RewriterSparsityEstimator.rollupSparsities(toCosts.get(i), RewriterSparsityEstimator.estimateAllNNZ(roots.get(i), ctx), ctx)).collect(Collectors.toList());
return requireCostCheck;
}
diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java
index 30a15131d19..6673ba2723f 100644
--- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java
+++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java
@@ -83,7 +83,6 @@ public static void testExpressionClustering() {
int pruneDataSubexrBiggerThan = 1000;
int maxCostSamples = 50;
- long startTime = System.currentTimeMillis();
AtomicLong generatedExpressions = new AtomicLong(0);
AtomicLong evaluatedExpressions = new AtomicLong(0);
AtomicLong failures = new AtomicLong(0);
diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java
index d2c914c9a3e..a8f1cd25b9b 100644
--- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java
+++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java
@@ -254,7 +254,7 @@ public void testConditional() {
@Test
public void codeGen() {
try {
- List lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH_CONDITIONAL));
+ List lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH_MB));
RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx);
RewriterRuntimeUtils.printUnknowns = false;