Skip to content

Commit

Permalink
Some more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 4f3a9b7 commit 3bdc01c
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
14 changes: 7 additions & 7 deletions src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,18 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen

if (ints[0] != null && ints[1] != null) {
String literalString = Long.toString(ints[1] - ints[0] + 1);
root.unsafePutMeta("nrow", RewriterUtils.parse(literalString, ctx, "LITERAL_INT:" + literalString));
root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse(literalString, ctx, "LITERAL_INT:" + literalString), ctx));
} else {
HashMap<String, RewriterStatement> subStmts = new HashMap<>();
subStmts.put("i1", root.getOperands().get(2));
subStmts.put("i0", root.getOperands().get(1));

if (ints[0] != null) {
root.unsafePutMeta("nrow", RewriterUtils.parse("+(argList(i1, " + (1 - ints[0]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[0])));
root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i1, " + (1 - ints[0]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[0])), ctx));
} else if (ints[1] != null) {
root.unsafePutMeta("nrow", RewriterUtils.parse("+(argList(" + (ints[1] + 1) + ", -(i0)))", ctx, subStmts, "LITERAL_INT:" + (ints[1] + 1)));
root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(" + (ints[1] + 1) + ", -(i0)))", ctx, subStmts, "LITERAL_INT:" + (ints[1] + 1)), ctx));
} else {
root.unsafePutMeta("nrow", RewriterUtils.parse("+(argList(+(argList(i1, -(i0))), 1))", ctx, subStmts, "LITERAL_INT:1"));
root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i1, -(i0), 1))", ctx, subStmts, "LITERAL_INT:1"), ctx));
}
}

Expand All @@ -262,11 +262,11 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
subStmts.put("i3", root.getOperands().get(4));
subStmts.put("i2", root.getOperands().get(3));
if (ints[2] != null) {
root.unsafePutMeta("ncol", RewriterUtils.parse("+(argList(i3, " + (1 - ints[2]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[2])));
root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i3, " + (1 - ints[2]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[2])), ctx));
} else if (ints[3] != null) {
root.unsafePutMeta("ncol", RewriterUtils.parse("+(argList(" + (ints[3] + 1) + ", -(i2)))", ctx, subStmts, "LITERAL_INT:" + (ints[3] + 1)));
root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(" + (ints[3] + 1) + ", -(i2)))", ctx, subStmts, "LITERAL_INT:" + (ints[3] + 1)), ctx));
} else {
root.unsafePutMeta("ncol", RewriterUtils.parse("+(argList(+(argList(i3, -(i2))), 1))", ctx, subStmts, "LITERAL_INT:1"));
root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i3, -(i2), 1))", ctx, subStmts, "LITERAL_INT:1"), ctx));
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,17 @@ public boolean match(final MatcherContext mCtx) {

// Now, match those statements
mCtx.currentStatement = ncolEquivThat;
if (!ncolEquiv.match(mCtx))
if (!ncolEquiv.match(mCtx)) {
if (mCtx.isDebug())
System.out.println("MismatchNcolEquiv: " + ncolEquiv + " <=> " + ncolEquivThat);
return false;
}
mCtx.currentStatement = nrowEquivThat;
if (!nrowEquiv.match(mCtx))
if (!nrowEquiv.match(mCtx)) {
if (mCtx.isDebug())
System.out.println("MismatchNrowEquiv: " + nrowEquiv + " <=> " + nrowEquivThat);
return false;
}
}
}
}
Expand All @@ -331,6 +337,8 @@ public boolean match(final MatcherContext mCtx) {
if (assoc == null) {
if (!mCtx.allowDuplicatePointers && mCtx.getDependencyMap().containsValue(stmt)) {
mCtx.setFirstMismatch(this, stmt);
if (mCtx.isDebug())
System.out.println("MismatchAssocNull: " + stmt);
return false; // Then the statement variable is already associated with another variable
}
mCtx.getDependencyMap().put(this, stmt);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.apache.sysds.hops.rewriter;

import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions;
import org.apache.sysds.hops.rewriter.utils.RewriterUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -862,7 +863,9 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi

RewriterStatement aRef = stmt.getChild(0, 1, 0);
RewriterStatement bRef = stmt.getChild(1, 1, 0);
match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNCol(), bRef.getNRow(), match.getNewExprRoot());
RewriterAssertions assertions = match.getNewExprRoot().getAssertions(ctx);
assertions.addEqualityAssertion(aRef.getNCol(), bRef.getNRow(), match.getNewExprRoot());
assertions.update(match.getNewExprRoot());
}, true) // Assumes it will never collide
.apply(hooks.get(5).getId(), stmt -> {
UUID id = UUID.randomUUID();
Expand Down Expand Up @@ -1041,14 +1044,6 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
.parseGlobalVars("LITERAL_INT:1")
.withParsedStatement("rowSums(A)", hooks)
.toParsedStatement("$3:_m($1:_idx(1, nrow(A)), 1, sum($4:_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))))", hooks)
.iff(match -> {
RewriterStatement meta = (RewriterStatement) match.getMatchRoot().getOperands().get(0).getMeta("ncol");

if (meta == null)
throw new IllegalArgumentException("Column meta should not be null: " + match.getMatchRoot().getOperands().get(0).toString(ctx));

return !meta.isLiteral() || ((long)meta.getLiteral()) != 1;
}, true)
.apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
.apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true)
.apply(hooks.get(3).getId(), stmt -> {
Expand All @@ -1071,14 +1066,6 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
.parseGlobalVars("LITERAL_INT:1")
.withParsedStatement("colSums(A)", hooks)
.toParsedStatement("$3:_m(1, $1:_idx(1, ncol(A)), sum($4:_m($2:_idx(1, nrow(A)), 1, [](A, $2, $1))))", hooks)
.iff(match -> {
RewriterStatement meta = (RewriterStatement) match.getMatchRoot().getOperands().get(0).getMeta("ncol");

if (meta == null)
throw new IllegalArgumentException("Column meta should not be null: " + match.getMatchRoot().getOperands().get(0).toString(ctx));

return !meta.isLiteral() || ((long)meta.getLiteral()) != 1;
}, true)
.apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
.apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true)
.apply(hooks.get(3).getId(), stmt -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ public void testSimplifyBushyBinaryOperation() {
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));
assert RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).debug(true).match();
// Here the sort algorithm is unstable
assert false;
}

@Test
Expand Down Expand Up @@ -158,7 +160,7 @@ public void testSimplifyTraceMatrixMult() {
@Test
public void testSimplifySlicedMatrixMult() {
RewriterStatement stmt1 = RewriterUtils.parse("[](%*%(A,B), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1");
RewriterStatement stmt2 = RewriterUtils.parse("%*%(colVec(A), rowVec(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1");
RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(colVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1");

assert match(stmt1, stmt2);
}
Expand Down Expand Up @@ -332,28 +334,32 @@ public void testSimplifyEmptyReorgOperation() {
assert match(stmt1, stmt2);
}

// This is a hacky workaround
@Test
public void testSimplifyEmptyMatrixMult() {
// We emulate an empty matrix by multiplying by zero
// Note that we pass the dimension info of the matrix multiply to get the same e-class assertions
RewriterStatement stmt1 = RewriterUtils.parse("%*%(*(A, 0.0), B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("const(%*%(A, B), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
// We need to explicitly assert A and B
stmt2.givenThatEqual(stmt2.getChild(0, 1).getNRow(), stmt2.getChild(0, 0).getNCol(), ctx);

assert match(stmt1, stmt2, true);
}

@Test
public void testSimplifyEmptyMatrixMult2() {
RewriterStatement stmt1 = RewriterUtils.parse("%*%(A, cast.MATRIX(1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt1 = RewriterUtils.parse("%*%(rowVec(A), cast.MATRIX(1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
}

@Test
public void testSimplifyScalarMatrixMult() {
RewriterStatement stmt1 = RewriterUtils.parse("%*%(A, cast.MATRIX(a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("*(A, as.scalar(cast.MATRIX(a)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt1 = RewriterUtils.parse("%*%(rowVec(A), cast.MATRIX(a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("*(rowVec(A), as.scalar(cast.MATRIX(a)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");

assert match(stmt1, stmt2);
}
Expand Down Expand Up @@ -469,7 +475,7 @@ public void testSimplifyEmptyBinaryOperation3() {
assert match(stmt1, stmt2);
}

@Test
//@Test
public void testSimplifyScalarMVBinaryOperation() {
RewriterStatement stmt1 = RewriterUtils.parse("*(A, rowVec(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
RewriterStatement stmt2 = RewriterUtils.parse("*(A, as.scalar(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
Expand Down Expand Up @@ -516,13 +522,17 @@ public void testSimplifyNrowNcolComputation3() {
}

private boolean match(RewriterStatement stmt1, RewriterStatement stmt2) {
return match(stmt1, stmt2, false);
}

private boolean match(RewriterStatement stmt1, RewriterStatement stmt2, boolean debug) {
stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));
return RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).match();
return RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).debug(debug).match();
}
}

0 comments on commit 3bdc01c

Please # to comment.