Skip to content

Commit

Permalink
Some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 0aae424 commit 636ea9d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,62 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
.build()
);

for (String t : ALL_TYPES) {
if (t.equals("MATRIX")) {
rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, v), b) => _m(i, j, ElementWiseInstruction(v, b))")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:v")
.parseGlobalVars(t + ":B")
.parseGlobalVars("INT:i,j")
.withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), B)", hooks)
.toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, [](B, i, j)))", hooks)
.link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta)
.link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta)
.apply(hooks.get(3).getId(), (stmt, match) -> {
// Then we an infer that the two matrices have the same dimensions
match.getExpressionRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNCol(), stmt.getChild(2, 1, 0).getNCol());
match.getExpressionRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNRow(), stmt.getChild(2, 1, 0).getNRow());
}, true)
.build()
);

continue;
}
rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, A), b) => _m(i, j, ElementWiseInstruction(A, b))")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:v")
.parseGlobalVars(t + ":b")
.parseGlobalVars("INT:i,j")
.withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), b)", hooks)
.toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, b))", hooks)
.link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta)
.link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta)
.build()
);

rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(A, v), i, j) => ElementWiseInstruction(v, [](A, i, j))")
.setUnidirectional(true)
.parseGlobalVars("MATRIX:A")
.parseGlobalVars(t + ":v")
.parseGlobalVars("INT:i,j")
.withParsedStatement("[]($1:ElementWiseInstruction(A, v), i, j)", hooks)
.toParsedStatement("$2:ElementWiseInstruction([](A, i, j), v)", hooks)
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
.build()
);

rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(v, A), i, j) => ElementWiseInstruction(v, [](A, i, j))")
.setUnidirectional(true)
.parseGlobalVars("MATRIX:A")
.parseGlobalVars(t + ":v")
.parseGlobalVars("INT:i,j")
.withParsedStatement("[]($1:ElementWiseInstruction(v, A), i, j)", hooks)
.toParsedStatement("$2:ElementWiseInstruction(v, [](A, i, j))", hooks)
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
.build()
);
}
}

public static void streamifyExpressions(final List<RewriterRule> rules, final RuleContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,65 @@
public class RewriterStreamTests {

private static RuleContext ctx;
private static Function<RewriterStatement, RewriterStatement> converter;
//private static Function<RewriterStatement, RewriterStatement> converter;
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;

@BeforeClass
public static void setup() {
ctx = RewriterUtils.buildDefaultContext();
converter = RewriterUtils.buildFusedOperatorCreator(ctx, true);
//converter = RewriterUtils.buildFusedOperatorCreator(ctx, true);
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true);
}

@Test
public void testAdditionFloat1() {
RewriterStatement stmt = RewriterUtils.parse("+(+(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
stmt = canonicalConverter.apply(stmt);
System.out.println(stmt.toParsableString(ctx, true));
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1")));
}

@Test
public void testAdditionFloat2() {
RewriterStatement stmt = RewriterUtils.parse("+(1, +(a, b))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
stmt = canonicalConverter.apply(stmt);
System.out.println(stmt.toParsableString(ctx, true));
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1")));
}

@Test
public void testAdditionMatrix1() {
RewriterStatement stmt = RewriterUtils.parse("+(+(A, B), 1)", ctx, "MATRIX:A,B", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
System.out.println(stmt.toParsableString(ctx, true));
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), +(argList([](B, $1, $2), [](A, $1, $2), 1)))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1")));
}
RewriterStatement stmt1 = RewriterUtils.parse("+(+(A, B), 1)", ctx, "MATRIX:A,B", "LITERAL_INT:1");
RewriterStatement stmt2 = RewriterUtils.parse("+(+(B, 1), A)", ctx, "MATRIX:A,B", "LITERAL_INT:1");

@Test
public void testAdditionMatrix2() {
RewriterStatement stmt = RewriterUtils.parse("+(1, +(A, B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
System.out.println(stmt.toParsableString(ctx, true));
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), +(argList([](B, $1, $2), [](A, $1, $2), 1)))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1")));
stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}

@Test
public void testSubtractionFloat1() {
RewriterStatement stmt = RewriterUtils.parse("+(-(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
stmt = canonicalConverter.apply(stmt);
System.out.println(stmt.toParsableString(ctx, true));
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(-(b), a, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1")));
}

@Test
public void testSubtractionFloat2() {
RewriterStatement stmt = RewriterUtils.parse("+(1, -(a, -(b, c)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b,c", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
stmt = canonicalConverter.apply(stmt);
System.out.println(stmt.toParsableString(ctx, true));
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(-(b), a, c, 1))", ctx, "FLOAT:a,b, c", "LITERAL_INT:0,1")));
}

@Test
// Fusion will no longer be pursued
/*@Test
public void testFusedPlanMatrixGeneration() {
RewriterStatement stmt = RewriterUtils.parse("+(1, +(A, B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1");
stmt = converter.apply(stmt);
Expand All @@ -94,7 +94,7 @@ public void testFusedPlanAdvancedAggregationGeneration() {
RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx);
System.out.println("Orig: " + stmt.toParsableString(ctx, true));
System.out.println("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true)));
}
}*/

@Test
public void testReorgEquivalence() {
Expand Down

0 comments on commit 636ea9d

Please # to comment.