From 4078d3476488c6ab3ccff7ae897caa4124e05d77 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Mon, 13 Jan 2025 12:01:55 +0100 Subject: [PATCH] Some more fixes --- .../rewriter/ConstantFoldingFunctions.java | 9 +++++ .../rewriter/RewriterContextSettings.java | 1 + .../sysds/hops/rewriter/RewriterDataType.java | 3 ++ .../hops/rewriter/RewriterInstruction.java | 5 ++- .../hops/rewriter/RewriterRuleCollection.java | 31 ++++++++++++---- .../hops/rewriter/RewriterStatement.java | 2 +- .../sysds/hops/rewriter/TopologicalSort.java | 9 ++--- .../hops/rewriter/utils/RewriterUtils.java | 7 +++- .../rewrite/RewriterNormalFormTests.java | 35 ++++++++----------- 9 files changed, 65 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java b/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java index 3dc5332244b..7d335f12a0d 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/ConstantFoldingFunctions.java @@ -54,6 +54,15 @@ public static boolean isNeutralElement(Object num, String op) { return false; } + public static boolean isNegNeutral(Object num, String op) { + switch (op) { + case "*": + return num.equals(-1L) || num.equals(-1.0D); + } + + return false; + } + public static boolean cancelOutNary(String op, List stmts) { Set toRemove = new HashSet<>(); switch (op) { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java index c6b6ad675b1..7ca085841fa 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java @@ -263,6 +263,7 @@ public static String getDefaultContextString() { builder.append("rand(INT,INT,FLOAT,FLOAT)::MATRIX\n"); // Args: rows, cols, min, max builder.append("rand(INT,INT)::FLOAT\n"); // Just to make it possible to say that random is dependent on both matrix indices + builder.append("rand(INT...)::FLOAT\n"); builder.append("matrix(INT,INT,INT)::MATRIX\n"); builder.append("trace(MATRIX)::FLOAT\n"); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java index 8d592290c5c..f75b4dea04d 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java @@ -339,6 +339,9 @@ public boolean match(final MatcherContext mCtx) { return true; } + if (mCtx.isDebug()) + System.out.println("MismatchAssoc: " + stmt + " <=> " + assoc); + mCtx.setFirstMismatch(this, stmt); return false; } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java index 5b994e5eae8..f2d2b709c1d 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java @@ -265,8 +265,11 @@ else if (mismatchCtr > 1) for (int i = 0; i < s; i++) { mCtx.currentStatement = inst.operands.get(i); - if (!operands.get(i).match(mCtx)) + if (!operands.get(i).match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("Mismatch: " + operands.get(i) + " <=> " + inst.operands.get(i)); return false; + } } mCtx.getInternalReferences().put(this, stmt); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java index aedf5e633df..323d51d0e14 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java @@ -758,6 +758,17 @@ public static void canonicalizeBooleanStatements(final List rules, public static void expandStreamingExpressions(final List rules, final RuleContext ctx) { HashMap hooks = new HashMap<>(); + // cast.MATRIX + rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.MATRIX(a)", hooks) + .toParsedStatement("$4:_m(1, 1, a)", hooks) + .build() + ); + // Const rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix") .setUnidirectional(true) @@ -970,7 +981,7 @@ public static void expandStreamingExpressions(final List rules, fi .parseGlobalVars("INT:n,m") .parseGlobalVars("FLOAT:a,b") .withParsedStatement("rand(n, m, a, b)", hooks) - .toParsedStatement("$3:_m($1:_idx(1, n), $2:_idx(1, m), +(a, *(+(b, -(a)), rand($1, $2))))", hooks) + .toParsedStatement("$3:_m($1:_idx(1, n), $2:_idx(1, m), +(a, *(+(b, -(a)), rand(argList($1,$2)))))", hooks) .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) .apply(hooks.get(3).getId(), stmt -> { @@ -1406,9 +1417,7 @@ public static void pushdownStreamSelections(final List rules, fina .build() ); - - // TODO: Deal with boolean or int matrices - rules.add(new RewriterRuleBuilder(ctx, "_m(i::, j::, v) => cast.MATRIX(v)") + /*rules.add(new RewriterRuleBuilder(ctx, "_m(i::, j::, v) => cast.MATRIX(v)") .setUnidirectional(true) .parseGlobalVars("MATRIX:A,B") .parseGlobalVars("INT:i,j") @@ -1424,7 +1433,7 @@ public static void pushdownStreamSelections(final List rules, fina return matching; }, true) .build() - ); + );*/ rules.add(new RewriterRuleBuilder(ctx, "_idx(a,a) => a") .setUnidirectional(true) @@ -1493,7 +1502,6 @@ public static void pushdownStreamSelections(final List rules, fina ); RewriterUtils.buildBinaryPermutations(List.of("FLOAT"), (t1, t2) -> { - // TODO: This probably first requires pulling out invariants of this idxExpr rules.add(new RewriterRuleBuilder(ctx, "*(sum(_idxExpr(i, ...)), sum(_idxExpr(j, ...))) => _idxExpr(i, _idxExpr(j, sum(*(...)))") .setUnidirectional(true) .parseGlobalVars("MATRIX:A,B") @@ -1722,6 +1730,17 @@ public static void buildElementWiseAlgebraicCanonicalization(final List { + rules.add(new RewriterRuleBuilder(ctx, "-(a) => *(-1.0, a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_" + t + ":-1") + .withParsedStatement("-(a)") + .toParsedStatement("*(-1, a)") + .build() + ); + }); } @Deprecated diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java index b57480f442a..21ad0235637 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -347,7 +347,7 @@ public MatcherContext debug(boolean debug) { } public boolean match() { - return matchRoot.match(this); + return thisExpressionRoot.match(this); } public boolean isDebug() { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java b/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java index 064a0b15b6d..77ebb7b9e53 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java @@ -14,7 +14,7 @@ // For now, we assume that _argList() will have one unique parent public class TopologicalSort { public static boolean DEBUG = false; - private static final Set SORTABLE_ARGLIST_OPS = Set.of("+", "-", "*", "_idxExpr", "_EClass"); + private static final Set SORTABLE_ARGLIST_OPS = Set.of("+", "-", "*", "_idxExpr", "_EClass", "rand"); private static final Set SORTABLE_OPS = Set.of("==", "!="); // TODO: Sort doesn't work if we have sth like _EClass(argList(nrow(U), nrow(V)), as the lowest address will be nrow, ncol and not U, V @@ -30,14 +30,12 @@ public static void sort(RewriterStatement root, final RuleContext ctx) { }, ctx); } + // TODO: Fails for E_Classes in DataTypes (matrix) if they do not occur elsewhere public static void sort(RewriterStatement root, BiFunction isArrangable, final RuleContext ctx) { List uncertainParents = setupOrderFacts(root, isArrangable, ctx); - //Set lowestUncertainties = findLowestUncertainties(root); - //setupAddresses(lowestUncertainties); buildAddresses(root, ctx); resolveAmbiguities(root, ctx, uncertainParents); - // TODO: Propagate address priorities and thus implicit orderings up the DAG resetAddresses(uncertainParents); int factCtr = 0; @@ -52,8 +50,6 @@ public static void sort(RewriterStatement root, BiFunction