Skip to content

Commit

Permalink
Some more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 7d6299d commit 4078d34
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ public static boolean isNeutralElement(Object num, String op) {
return false;
}

public static boolean isNegNeutral(Object num, String op) {
switch (op) {
case "*":
return num.equals(-1L) || num.equals(-1.0D);
}

return false;
}

public static boolean cancelOutNary(String op, List<RewriterStatement> stmts) {
Set<Integer> toRemove = new HashSet<>();
switch (op) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ public static String getDefaultContextString() {

builder.append("rand(INT,INT,FLOAT,FLOAT)::MATRIX\n"); // Args: rows, cols, min, max
builder.append("rand(INT,INT)::FLOAT\n"); // Just to make it possible to say that random is dependent on both matrix indices
builder.append("rand(INT...)::FLOAT\n");
builder.append("matrix(INT,INT,INT)::MATRIX\n");

builder.append("trace(MATRIX)::FLOAT\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ public boolean match(final MatcherContext mCtx) {
return true;
}

if (mCtx.isDebug())
System.out.println("MismatchAssoc: " + stmt + " <=> " + assoc);

mCtx.setFirstMismatch(this, stmt);
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,11 @@ else if (mismatchCtr > 1)
for (int i = 0; i < s; i++) {
mCtx.currentStatement = inst.operands.get(i);

if (!operands.get(i).match(mCtx))
if (!operands.get(i).match(mCtx)) {
if (mCtx.isDebug())
System.out.println("Mismatch: " + operands.get(i) + " <=> " + inst.operands.get(i));
return false;
}
}

mCtx.getInternalReferences().put(this, stmt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,17 @@ public static void canonicalizeBooleanStatements(final List<RewriterRule> rules,
public static void expandStreamingExpressions(final List<RewriterRule> rules, final RuleContext ctx) {
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();

// cast.MATRIX
rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:a")
.parseGlobalVars("MATRIX:A")
.parseGlobalVars("LITERAL_INT:1")
.withParsedStatement("cast.MATRIX(a)", hooks)
.toParsedStatement("$4:_m(1, 1, a)", hooks)
.build()
);

// Const
rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix")
.setUnidirectional(true)
Expand Down Expand Up @@ -970,7 +981,7 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
.parseGlobalVars("INT:n,m")
.parseGlobalVars("FLOAT:a,b")
.withParsedStatement("rand(n, m, a, b)", hooks)
.toParsedStatement("$3:_m($1:_idx(1, n), $2:_idx(1, m), +(a, *(+(b, -(a)), rand($1, $2))))", hooks)
.toParsedStatement("$3:_m($1:_idx(1, n), $2:_idx(1, m), +(a, *(+(b, -(a)), rand(argList($1,$2)))))", hooks)
.apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
.apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true)
.apply(hooks.get(3).getId(), stmt -> {
Expand Down Expand Up @@ -1406,9 +1417,7 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
.build()
);


// TODO: Deal with boolean or int matrices
rules.add(new RewriterRuleBuilder(ctx, "_m(i::<const>, j::<const>, v) => cast.MATRIX(v)")
/*rules.add(new RewriterRuleBuilder(ctx, "_m(i::<const>, j::<const>, v) => cast.MATRIX(v)")
.setUnidirectional(true)
.parseGlobalVars("MATRIX:A,B")
.parseGlobalVars("INT:i,j")
Expand All @@ -1424,7 +1433,7 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
return matching;
}, true)
.build()
);
);*/

rules.add(new RewriterRuleBuilder(ctx, "_idx(a,a) => a")
.setUnidirectional(true)
Expand Down Expand Up @@ -1493,7 +1502,6 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
);

RewriterUtils.buildBinaryPermutations(List.of("FLOAT"), (t1, t2) -> {
// TODO: This probably first requires pulling out invariants of this idxExpr
rules.add(new RewriterRuleBuilder(ctx, "*(sum(_idxExpr(i, ...)), sum(_idxExpr(j, ...))) => _idxExpr(i, _idxExpr(j, sum(*(...)))")
.setUnidirectional(true)
.parseGlobalVars("MATRIX:A,B")
Expand Down Expand Up @@ -1722,6 +1730,17 @@ public static void buildElementWiseAlgebraicCanonicalization(final List<Rewriter
.build()
);
});

List.of("FLOAT", "INT").forEach(t -> {
rules.add(new RewriterRuleBuilder(ctx, "-(a) => *(-1.0, a)")
.setUnidirectional(true)
.parseGlobalVars(t + ":a")
.parseGlobalVars("LITERAL_" + t + ":-1")
.withParsedStatement("-(a)")
.toParsedStatement("*(-1, a)")
.build()
);
});
}

@Deprecated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ public MatcherContext debug(boolean debug) {
}

public boolean match() {
return matchRoot.match(this);
return thisExpressionRoot.match(this);
}

public boolean isDebug() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// For now, we assume that _argList() will have one unique parent
public class TopologicalSort {
public static boolean DEBUG = false;
private static final Set<String> SORTABLE_ARGLIST_OPS = Set.of("+", "-", "*", "_idxExpr", "_EClass");
private static final Set<String> SORTABLE_ARGLIST_OPS = Set.of("+", "-", "*", "_idxExpr", "_EClass", "rand");
private static final Set<String> SORTABLE_OPS = Set.of("==", "!=");

// TODO: Sort doesn't work if we have sth like _EClass(argList(nrow(U), nrow(V)), as the lowest address will be nrow, ncol and not U, V
Expand All @@ -30,14 +30,12 @@ public static void sort(RewriterStatement root, final RuleContext ctx) {
}, ctx);
}

// TODO: Fails for E_Classes in DataTypes (matrix) if they do not occur elsewhere
public static void sort(RewriterStatement root, BiFunction<RewriterStatement, RewriterStatement, Boolean> isArrangable, final RuleContext ctx) {
List<RewriterStatement> uncertainParents = setupOrderFacts(root, isArrangable, ctx);

//Set<UnorderedSet> lowestUncertainties = findLowestUncertainties(root);
//setupAddresses(lowestUncertainties);
buildAddresses(root, ctx);
resolveAmbiguities(root, ctx, uncertainParents);
// TODO: Propagate address priorities and thus implicit orderings up the DAG
resetAddresses(uncertainParents);

int factCtr = 0;
Expand All @@ -52,8 +50,6 @@ public static void sort(RewriterStatement root, BiFunction<RewriterStatement, Re
System.out.println("Lowest uncertainties: " + lowestUncertainties);
}

// TODO: Don't introduce the facts to the lowest uncertainties but their leaves first
// TODO: This avoids stuff like argList(ncol(U), ncol(V))
factCtr = introduceFacts(lowestUncertainties, factCtr);
buildAddresses(root, ctx);

Expand All @@ -68,7 +64,6 @@ public static void sort(RewriterStatement root, BiFunction<RewriterStatement, Re

resolveAmbiguities(root, ctx, uncertainParents);
resetAddresses(uncertainParents);
// TODO: Propagate address priorities and thus implicit orderings up the DAG

lowestUncertainties = findLowestUncertainties(root);
ctr++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,12 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final
return argList.get(0);
else if (argList.isEmpty())
return neutral;
}
} /*else if (argList.size() == 2 && ConstantFoldingFunctions.isNegNeutral(argList.get(literals[0]).getLiteral(), stmt.trueInstruction())) {
RewriterStatement neutral = argList.get(literals[0]);
argList.remove(literals[0]);
return new RewriterInstruction("-", ctx, argList.get(0));
}*/
}

if (literals.length < 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ public void testSimplifyBushyBinaryOperation() {
RewriterStatement stmt1 = RewriterUtils.parse("*(A,*(B, %*%(C, rowVec(D))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0");
RewriterStatement stmt2 = RewriterUtils.parse("*(*(A,B), %*%(C, rowVec(D)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0");

assert match(stmt1, stmt2);
stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));
assert RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).debug(true).match();
}

@Test
Expand Down Expand Up @@ -273,14 +280,6 @@ public void testSimplifyColwiseAggregate() {
assert match(stmt1, stmt2);
}

@Test
public void testSimplifyColwiseAggregate2() {
RewriterStatement stmt1 = RewriterUtils.parse("colSums(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("cast.MATRIX(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
}

@Test
public void testSimplifyRowwiseAggregate() {
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
Expand All @@ -289,26 +288,20 @@ public void testSimplifyRowwiseAggregate() {
assert match(stmt1, stmt2);
}

@Test
public void testSimplifyRowwiseAggregate2() {
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("cast.MATRIX(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
}

// We don't have broadcasting semantics
@Test
public void testSimplifyColSumsMVMult() {
RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(A, colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(B)), A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(rowVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(rowVec(B)), rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
}

// We don't have broadcasting semantics
@Test
public void testSimplifyRowSumsMVMult() {
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(A, rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("%*%(A, rowVec(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(colVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("%*%(colVec(A), t(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
}
Expand Down

0 comments on commit 4078d34

Please # to comment.