From 636ea9d4a99e2f8c50ea65162107c2d6bbb3fd0f Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Wed, 23 Oct 2024 13:11:49 +0200 Subject: [PATCH] Some fixes --- .../hops/rewriter/RewriterRuleCollection.java | 56 +++++++++++++++++++ .../codegen/rewrite/RewriterStreamTests.java | 38 ++++++------- 2 files changed, 75 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java index afccc276399..554f509c4a8 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java @@ -1067,6 +1067,62 @@ public static void pushdownStreamSelections(final List rules, fina .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) .build() ); + + for (String t : ALL_TYPES) { + if (t.equals("MATRIX")) { + rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, v), b) => _m(i, j, ElementWiseInstruction(v, b))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:v") + .parseGlobalVars(t + ":B") + .parseGlobalVars("INT:i,j") + .withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), B)", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, [](B, i, j)))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), (stmt, match) -> { + // Then we an infer that the two matrices have the same dimensions + match.getExpressionRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNCol(), stmt.getChild(2, 1, 0).getNCol()); + match.getExpressionRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNRow(), stmt.getChild(2, 1, 0).getNRow()); + }, true) + .build() + ); + + continue; + } + rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, A), b) => _m(i, j, ElementWiseInstruction(A, b))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:v") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), b)", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, b))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(A, v), i, j) => ElementWiseInstruction(v, [](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":v") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseInstruction(A, v), i, j)", hooks) + .toParsedStatement("$2:ElementWiseInstruction([](A, i, j), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(v, A), i, j) => ElementWiseInstruction(v, [](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":v") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseInstruction(v, A), i, j)", hooks) + .toParsedStatement("$2:ElementWiseInstruction(v, [](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + } } public static void streamifyExpressions(final List rules, final RuleContext ctx) { diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java index 3a6ae215861..20203564ed2 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java @@ -11,20 +11,20 @@ public class RewriterStreamTests { private static RuleContext ctx; - private static Function converter; + //private static Function converter; private static Function canonicalConverter; @BeforeClass public static void setup() { ctx = RewriterUtils.buildDefaultContext(); - converter = RewriterUtils.buildFusedOperatorCreator(ctx, true); + //converter = RewriterUtils.buildFusedOperatorCreator(ctx, true); canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true); } @Test public void testAdditionFloat1() { RewriterStatement stmt = RewriterUtils.parse("+(+(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1"); - stmt = converter.apply(stmt); + stmt = canonicalConverter.apply(stmt); System.out.println(stmt.toParsableString(ctx, true)); assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"))); } @@ -32,31 +32,30 @@ public void testAdditionFloat1() { @Test public void testAdditionFloat2() { RewriterStatement stmt = RewriterUtils.parse("+(1, +(a, b))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"); - stmt = converter.apply(stmt); + stmt = canonicalConverter.apply(stmt); System.out.println(stmt.toParsableString(ctx, true)); assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"))); } @Test public void testAdditionMatrix1() { - RewriterStatement stmt = RewriterUtils.parse("+(+(A, B), 1)", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); - stmt = converter.apply(stmt); - System.out.println(stmt.toParsableString(ctx, true)); - assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), +(argList([](B, $1, $2), [](A, $1, $2), 1)))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"))); - } + RewriterStatement stmt1 = RewriterUtils.parse("+(+(A, B), 1)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(+(B, 1), A)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); - @Test - public void testAdditionMatrix2() { - RewriterStatement stmt = RewriterUtils.parse("+(1, +(A, B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); - stmt = converter.apply(stmt); - System.out.println(stmt.toParsableString(ctx, true)); - assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), +(argList([](B, $1, $2), [](A, $1, $2), 1)))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"))); + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + System.out.println("=========="); + System.out.println(stmt1.toParsableString(ctx, true)); + System.out.println("=========="); + System.out.println(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2)); } @Test public void testSubtractionFloat1() { RewriterStatement stmt = RewriterUtils.parse("+(-(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1"); - stmt = converter.apply(stmt); + stmt = canonicalConverter.apply(stmt); System.out.println(stmt.toParsableString(ctx, true)); assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(-(b), a, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"))); } @@ -64,12 +63,13 @@ public void testSubtractionFloat1() { @Test public void testSubtractionFloat2() { RewriterStatement stmt = RewriterUtils.parse("+(1, -(a, -(b, c)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b,c", "LITERAL_INT:0,1"); - stmt = converter.apply(stmt); + stmt = canonicalConverter.apply(stmt); System.out.println(stmt.toParsableString(ctx, true)); assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(-(b), a, c, 1))", ctx, "FLOAT:a,b, c", "LITERAL_INT:0,1"))); } - @Test + // Fusion will no longer be pursued + /*@Test public void testFusedPlanMatrixGeneration() { RewriterStatement stmt = RewriterUtils.parse("+(1, +(A, B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); stmt = converter.apply(stmt); @@ -94,7 +94,7 @@ public void testFusedPlanAdvancedAggregationGeneration() { RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); System.out.println("Orig: " + stmt.toParsableString(ctx, true)); System.out.println("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); - } + }*/ @Test public void testReorgEquivalence() {