From a012b3279d72130e22ff82e8d3ece4730eb3b36f Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Mon, 9 Dec 2024 17:34:21 +0100 Subject: [PATCH] Bugfix --- .../rewriter/ConstantFoldingFunctions.java | 12 +-- .../hops/rewriter/GeneratedRewriteClass.java | 5 + .../sysds/hops/rewriter/RewriterDataType.java | 10 +- .../hops/rewriter/RewriterInstruction.java | 2 +- .../sysds/hops/rewriter/RewriterRuleSet.java | 19 ++-- .../hops/rewriter/RewriterStatement.java | 5 +- .../rewriter/codegen/CodeGenCondition.java | 94 ++++++++++++++++++- .../hops/rewriter/codegen/CodeGenUtils.java | 2 + .../rewriter/codegen/RewriterCodeGen.java | 43 ++++++++- .../estimators/RewriterCostEstimator.java | 2 +- .../rewrite/functions/CodeGenTests.java | 9 +- 11 files changed, 174 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java b/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java index 5d908341cc1..3dc5332244b 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java @@ -30,13 +30,13 @@ else if (type.equals("INT")) if (type.equals("FLOAT")) return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMinFloat((double)num, stmt); else if (type.equals("INT")) - return (num, stmt) -> num == null ? stmt.intLiteral() : foldMinInt((long)num, stmt); + return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMinInt((long)num, stmt); break; case "max": if (type.equals("FLOAT")) return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMaxFloat((double)num, stmt); else if (type.equals("INT")) - return (num, stmt) -> num == null ? stmt.intLiteral() : foldMaxInt((long)num, stmt); + return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMaxInt((long)num, stmt); break; } @@ -122,7 +122,7 @@ public static double foldSumFloat(double num, RewriterStatement next) { } public static long foldSumInt(long num, RewriterStatement next) { - return num + next.intLiteral(); + return num + next.intLiteral(false); } public static double foldMulFloat(double num, RewriterStatement next) { @@ -130,7 +130,7 @@ public static double foldMulFloat(double num, RewriterStatement next) { } public static long foldMulInt(long num, RewriterStatement next) { - return num * next.intLiteral(); + return num * next.intLiteral(false); } public static double foldMinFloat(double num, RewriterStatement next) { @@ -138,7 +138,7 @@ public static double foldMinFloat(double num, RewriterStatement next) { } public static long foldMinInt(long num, RewriterStatement next) { - return Math.min(num, next.intLiteral()); + return Math.min(num, next.intLiteral(false)); } public static double foldMaxFloat(double num, RewriterStatement next) { @@ -146,6 +146,6 @@ public static double foldMaxFloat(double num, RewriterStatement next) { } public static long foldMaxInt(long num, RewriterStatement next) { - return Math.max(num, next.intLiteral()); + return Math.max(num, next.intLiteral(false)); } } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/GeneratedRewriteClass.java b/src/main/java/org/apache/sysds/hops/rewriter/GeneratedRewriteClass.java index f0aa51552c8..9ab6d9ff663 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/GeneratedRewriteClass.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/GeneratedRewriteClass.java @@ -180,10 +180,13 @@ public Object apply( Object _hi ) { hi = _applyRewrite89(hi); // +(-(0.0,b),A) => -(A,b) } else { if ( hi_1.getDataType() == Types.DataType.MATRIX ) { + System.out.println("HERE0"); if ( hi_1 instanceof UnaryOp ) { + System.out.println("a"); hi = _applyRewrite30(hi); // +(a,cast.MATRIX(0.0)) => cast.MATRIX(a) hi = _applyRewrite72(hi); // +(a,cast.MATRIX(b)) => cast.MATRIX(+(a,b)) } else if ( hi_1 instanceof BinaryOp ) { + System.out.println("b"); if ( (( BinaryOp ) hi_1 ).getOp() == Types.OpOp2.MINUS ) { if ( hi_1.getInput().size() == 2 ) { Hop hi_1_0 = hi_1.getInput(0); @@ -214,6 +217,7 @@ public Object apply( Object _hi ) { } } } else if ( hi_1 instanceof ReorgOp ) { + System.out.println("c"); if ( (( ReorgOp ) hi_1 ).getOp() == Types.ReOrgOp.REV ) { hi = _applyRewrite283(hi); // +(a,rev($1:-(b,C))) => -(+(a,b),rev(C)) hi = _applyRewrite288(hi); // +(a,rev($1:-(C,b))) => +(-(a,b),rev(C)) @@ -226,6 +230,7 @@ public Object apply( Object _hi ) { hi = _applyRewrite356(hi); // +(a,t($1:+(C,b))) => +(+(a,b),t(C)) } } else { + System.out.println("HERE1"); hi = _applyRewrite5(hi); // +(0.0,A) => A } } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java index ffc85ca1cd9..5d45ca82fdc 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java @@ -92,9 +92,15 @@ public Object getLiteral() { } @Override - public long intLiteral() { + public long intLiteral(boolean cast) { if (getLiteral() instanceof Boolean) return (boolean)getLiteral() ? 1 : 0; + + if (cast && getLiteral() instanceof Double) { + double val = floatLiteral(); + return (long)val; + } + return (long)getLiteral(); } @@ -113,7 +119,7 @@ public boolean boolLiteral() { return (boolean)getLiteral(); if (getLiteral() instanceof Long) return (long)getLiteral() == 0L; - return (double)getLiteral() == 0.0; + return (double)getLiteral() == 0.0D; } @Override diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java index 1f889afbab1..9cf9704a965 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java @@ -111,7 +111,7 @@ public RewriterStatement getLiteralStatement() { } @Override - public long intLiteral() { + public long intLiteral(boolean cast) { throw new UnsupportedOperationException(); } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java index c9f9435611b..e6acb54cc40 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -274,27 +275,29 @@ public String serialize(final RuleContext ctx) { return sb.toString(); } - public boolean generateCodeAndTest(boolean optimize, boolean print) { + public Set generateCodeAndTest(boolean optimize, boolean print) { String javaCode = toJavaCode("MGeneratedRewriteClass", optimize, false, true); Function f = RewriterCodeGen.compile(javaCode, "MGeneratedRewriteClass"); if (f == null) - return false; // Then, the code could not compile + return null; // Then, the code could not compile - int origSize = rules.size(); + //int origSize = rules.size(); + Set removed = new HashSet<>(); for (int i = 0; i < rules.size(); i++) { if (!RewriterRuleCreator.validateRuleApplicability(rules.get(i), ctx, print, f)) { System.out.println("Faulty rule: " + rules.get(i)); - rules.remove(i); - i--; + removed.add(rules.get(i)); + //rules.remove(i); + //i--; } } - if (rules.size() != origSize) - accelerate(); + //if (rules.size() != origSize) + // accelerate(); - return true; + return removed; } public static RewriterRuleSet deserialize(String data, final RuleContext ctx) { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java index b58326ac8c1..4f05975dbba 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -479,7 +479,10 @@ public void replace(RewriterStatement newStmt) { public abstract boolean isLiteral(); public abstract Object getLiteral(); public abstract RewriterStatement getLiteralStatement(); - public abstract long intLiteral(); + public long intLiteral() { + return intLiteral(false); + } + public abstract long intLiteral(boolean cast); public abstract double floatLiteral(); public abstract boolean boolLiteral(); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java index 01e72176841..0aeebe5f578 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java @@ -1,6 +1,7 @@ package org.apache.sysds.hops.rewriter.codegen; import javassist.compiler.CodeGen; +import org.apache.sysds.hops.rewriter.RewriterDataType; import org.apache.sysds.hops.rewriter.RewriterRule; import org.apache.sysds.hops.rewriter.RewriterStatement; import org.apache.sysds.hops.rewriter.RuleContext; @@ -28,6 +29,7 @@ public enum ConditionDataType { private ConditionType conditionType; private Object conditionValue; private List rulesIf; + private List applyAnyway; private List relativeChildPath; private RewriterStatement representant; @@ -35,6 +37,7 @@ private CodeGenCondition(ConditionType cType, Object cValue, List relat conditionType = cType; conditionValue = cValue; rulesIf = new ArrayList<>(); + applyAnyway = new ArrayList<>(); this.relativeChildPath = relativeChildPath; this.representant = representant; @@ -85,6 +88,11 @@ private static List populateLayerRecursively(List rules, List) c4.rulesIf.get(0))._2 == null) { + continue; // TODO: Is that correct? + } + final int maxIndex = ((Tuple2) c4.rulesIf.get(0))._2.getOperands().size(); Set activeRules = c4.rulesIf.stream().map(o -> ((Tuple2) o)._1).collect(Collectors.toSet()); Queue, List>> mQueue = new LinkedList<>(); @@ -109,7 +117,7 @@ private static List populateLayerRecursively(List rules, List { Tuple2 t = (Tuple2) o; - mList.add(new Tuple2(t._1, t._2.getChild(mIdx))); + mList.add(new Tuple2(t._1, (t._2 == null ? null : (t._2.getOperands().isEmpty() ? null : t._2.getChild(mIdx))))); }); } @@ -135,16 +143,23 @@ private static List populateLayerRecursively(List rules, List rules, List generatedConditions) { int origSize = rules.size(); int newSize = generatedConditions.stream().mapToInt(o -> ((CodeGenCondition)o).rulesIf.size()).sum(); - return origSize == newSize; + return origSize <= newSize; } private static List populateDataTypeLayer(List rules, List relativeChildPath, final RuleContext ctx) { List conds = new ArrayList<>(); + List> defer = new ArrayList<>(); //System.out.println("====="); for (Object o : rules) { Tuple2 t = (Tuple2) o; + + if (t._2 == null) { + defer.add(t); + continue; + } + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { CodeGenCondition cond = CodeGenCondition.conditionalDataType(t._2, relativeChildPath, t._2, ctx); cond.insertIfMatches(t, ctx); @@ -162,6 +177,21 @@ private static List populateDataTypeLayer(List rules, List(defer), relativeChildPath, null, ctx)); + } + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + // Validate + /*conds.stream().map(cond -> ((CodeGenCondition)cond)).forEach(cond -> { + if (cond.rulesIf.isEmpty()) + throw new IllegalArgumentException(); + });*/ + if (!validateSizeMaintenance(rules, conds)) throw new IllegalArgumentException(); @@ -171,10 +201,17 @@ private static List populateDataTypeLayer(List rules, List populateOpClassLayer(List l, List relativeChildPath, final RuleContext ctx) { List conds = new ArrayList<>(); List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); for (Object o : l) { try { Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + if (canGenerateOpClassCheck(t._2, ctx)) { if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { CodeGenCondition cond = CodeGenCondition.conditionalOpClass(t._2, relativeChildPath, t._2, ctx); @@ -189,12 +226,19 @@ private static List populateOpClassLayer(List l, List r } } + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + if (!remaining.isEmpty()) { conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); } - //if (!validateSizeMaintenance(l, conds)) - // throw new IllegalArgumentException(); + if (!validateSizeMaintenance(l, conds)) + throw new IllegalArgumentException(); return conds; } @@ -202,20 +246,38 @@ private static List populateOpClassLayer(List l, List r private static List populateOpCodeLayer(List l, List relativeChildPath, final RuleContext ctx) { List conds = new ArrayList<>(); List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); for (Object o : l) { Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + if (canGenerateOpCodeCheck(t._2, ctx)) { if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { CodeGenCondition cond = CodeGenCondition.conditionalOpCode(t._2, relativeChildPath, t._2, ctx); cond.insertIfMatches(t, ctx); conds.add(cond); } + } else if (t._2 instanceof RewriterDataType && !t._2.isLiteral()) { + // Then we must add it to all conditions + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(t); } else { remaining.add(t); } } + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + if (!remaining.isEmpty()) { conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); } @@ -229,20 +291,38 @@ private static List populateOpCodeLayer(List l, List re private static List populateInputSizeLayer(List l, List relativeChildPath, final RuleContext ctx) { List conds = new ArrayList<>(); List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); for (Object o : l) { Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + if (canGenerateInputSizeCheck(t._2, ctx)) { if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { CodeGenCondition cond = CodeGenCondition.conditionalInputSize(t._2.getOperands().size(), relativeChildPath, t._2, ctx); cond.insertIfMatches(t, ctx); conds.add(cond); } + } else if (t._2 instanceof RewriterDataType && !t._2.isLiteral()) { + // Then we must add it to all conditions + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(t); } else { remaining.add(t); } } + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + if (!remaining.isEmpty()) { conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); } @@ -467,6 +547,9 @@ public static void buildSelection(StringBuilder sb, List conds if (nestedCondition.isEmpty()) { List> cur = firstCond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + if (cur.isEmpty()) + throw new IllegalArgumentException(firstCond.rulesIf.toString()); + for (Tuple2 t : cur) { String fMapping = ruleFunctionMappings.get(t._1); if (fMapping != null) { @@ -515,6 +598,9 @@ public static void buildSelection(StringBuilder sb, List conds if (mNestedCondition.isEmpty()) { List> cur = cond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + if (cur.isEmpty()) + throw new IllegalArgumentException(cond.rulesIf.toString()); + for (Tuple2 t : cur) { String fMapping = ruleFunctionMappings.get(t._1); if (fMapping != null) { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenUtils.java index 59cd7192288..cdab382c268 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenUtils.java @@ -1,10 +1,12 @@ package org.apache.sysds.hops.rewriter.codegen; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; import org.apache.sysds.hops.rewriter.RewriterStatement; import org.apache.sysds.hops.rewriter.RuleContext; +import javax.annotation.Nullable; import java.util.Map; import java.util.Optional; diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java index f1252097633..3143fa50ac0 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java @@ -1,7 +1,9 @@ package org.apache.sysds.hops.rewriter.codegen; +import org.apache.sysds.common.Types; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; import org.apache.sysds.hops.rewriter.RewriterDataType; @@ -115,7 +117,7 @@ public static String generateClass(String className, List t : implementedRewrites) ruleNames.put(t._2, t._1); - List conditions = CodeGenCondition.buildCondition(rules, 5, ctx); + List conditions = CodeGenCondition.buildCondition(rules, 20, ctx); CodeGenCondition.buildSelection(msb, conditions, 2, ruleNames, ctx); } else { for (Tuple2 appliedRewrites : rewrites) { @@ -369,7 +371,7 @@ private static void buildCostFnRecursively(RewriterStatement costFn, Map buildRewrite(RewriterStatement newRoot, StringBuilder sb, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation) { Set visited = new HashSet<>(); - recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited); + recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT")); //indent(indentation, sb); //sb.append("hi = " + vars.get(newRoot) + ";\n"); @@ -387,19 +389,50 @@ private static void removeUnreferencedHops(RewriterStatement oldRoot, Set vars, final RuleContext ctx, int indentation, int varCtr, Set visited) { + private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation, int varCtr, Set visited, boolean enforceRootDataType) { visited.add(cur); if (vars.containsKey(cur)) return varCtr; for (RewriterStatement child : cur.getOperands()) - varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited); + varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false); if (cur instanceof RewriterDataType) { if (cur.isLiteral()) { indent(indentation, sb); String name = "l" + (varCtr++); - sb.append("LiteralOp " + name + " = new LiteralOp( " + cur.getLiteral() + " );\n"); + String literalStr = cur.getLiteral().toString(); + + if (enforceRootDataType) { + sb.append("LiteralOp " + name + ";"); + indent(indentation, sb); + sb.append("switch (hi.getValueType()) {\n"); + indent(indentation+1, sb); + sb.append("case Types.ValueType.FP64:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.floatLiteral() + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("case Types.ValueType.INT64:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.intLiteral(true) + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("case Types.ValueType.BOOLEAN:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.boolLiteral() + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation, sb); + sb.append("}\n"); + } else { + sb.append("LiteralOp " + name + " = new LiteralOp( " + literalStr + " );\n"); + } vars.put(cur, name); } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java index 513a7ecd176..3f27d51dfde 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java @@ -434,7 +434,7 @@ public static long computeCostFunction(RewriterStatement costFn, Function(nrowLiteral.intLiteral(), ncolLiteral.intLiteral()))); + mNew = RewriterStatement.literal(ctx, nnzGenerator.apply(op, new Tuple2<>(nrowLiteral.intLiteral(false), ncolLiteral.intLiteral(false)))); map.put(op, mNew); cur.getOperands().set(i, mNew); } diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java index dbb79b3f9e8..ff0ac3a3f56 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java @@ -29,6 +29,7 @@ import java.nio.file.Paths; import java.util.HashMap; import java.util.List; +import java.util.Set; import java.util.function.Function; public class CodeGenTests { @@ -216,7 +217,13 @@ public void codeGen() { RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); RewriterRuntimeUtils.printUnknowns = false; - ruleSet.generateCodeAndTest(false, true); + Set invalid_unoptimized = ruleSet.generateCodeAndTest(false, true); + Set invalid_optimized = ruleSet.generateCodeAndTest(true, true); + System.out.println("========== DIFF ==========="); + invalid_optimized.removeAll(invalid_unoptimized); + for (RewriterRule rule : invalid_optimized) { + System.out.println(rule); + } RewriterCodeGen.DEBUG = true; String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", true, true, true);