Skip to content

Commit

Permalink
Some more bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 2439b0f commit 78a5561
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1312,9 +1312,7 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
}, true)
.linkUnidirectional(hooks.get(1).getId(), hooks.get(2).getId(), lnk -> {
RewriterStatement.transferMeta(lnk);
System.out.println("HERE");

// TODO: Big issue when having multiple references to the same sub-dag
for (int idx = 0; idx < 2; idx++) {
RewriterStatement oldRef = lnk.oldStmt.getChild(idx);

Expand All @@ -1340,10 +1338,7 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
}
}, true)
.apply(hooks.get(3).getId(), stmt -> {
System.out.println("BEFORE: " + stmt.toParsableString(ctx));
stmt.getOperands().set(0, stmt.getChild(0, 2));
System.out.println("AFTER: " + stmt.toParsableString(ctx));
System.out.println("Cnt: " + stmt.getChild(0).refCtr);
}, true)
.build()
);
Expand Down Expand Up @@ -1379,6 +1374,21 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, mStmtC).consolidate(ctx);
final RewriterStatement newStmt = RewriterUtils.foldConstants(mStmt, ctx);

/*UUID oldRefId = (UUID)oldRef.getMeta("idxId");
RewriterStatement newOne = RewriterUtils.replaceReferenceAware(lnk.newStmt.get(0).getChild(2), stmt -> {
UUID idxId = (UUID) stmt.getMeta("idxId");
if (idxId != null) {
if (idxId.equals(oldRefId))
return newStmt;
}
return null;
});
if (newOne != null)
lnk.newStmt.get(0).getOperands().set(2, newOne);*/

// Replace all references to h with
lnk.newStmt.get(0).getOperands().get(2).forEachPostOrder((el, pred) -> {
for (int i = 0; i < el.getOperands().size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,11 @@ public static RewriterStatement replaceReferenceAware(RewriterStatement root, bo
RewriterStatement newSub = replaceReferenceAware(root.getOperands().get(i), duplicateReferences, comparer, visited);

if (newSub != null) {
System.out.println("NewSub: " + newSub);
//System.out.println("NewSub: " + newSub);
if (duplicateReferences && newOne == null) {
root = root.copyNode();
newOne = root;
System.out.println("Duplication required: " + root);
//System.out.println("Duplication required: " + root);
}

root.getOperands().set(i, newSub);
Expand All @@ -844,6 +844,27 @@ public static RewriterStatement replaceReferenceAware(RewriterStatement root, bo
return newOne;
}

public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) {
for (int i = 0; i < root.getOperands().size(); i++) {
RewriterStatement child = root.getChild(i);
if (child.isInstruction() && child.refCtr > 1) {
if (!child.trueInstruction().equals("_idx")
&& !child.trueInstruction().equals("_m")
&& !child.trueInstruction().equals("idxExpr")
//&& !child.trueInstruction().equals("argList")
&& !child.trueInstruction().equals("_EClass")) {
RewriterStatement cpy = child.copyNode();
root.getOperands().set(i, cpy);
child.refCtr--;
cpy.getOperands().forEach(op -> op.refCtr++);
//System.out.println("Copied: " + child.trueInstruction());
}
}

unfoldExpressions(child, ctx);
}
}

// Function to check if two lists match
public static <T> boolean findMatchingOrderings(List<T> col1, List<T> col2, T[] stack, BiFunction<T, T, Boolean> matcher, Function<T[], Boolean> permutationEmitter, boolean symmetric) {
if (col1.size() != col2.size())
Expand Down Expand Up @@ -1406,6 +1427,8 @@ public static Function<RewriterStatement, RewriterStatement> buildCanonicalFormC
RewriterUtils.mergeArgLists(stmt, ctx);

stmt = stmt.getAssertions(ctx).cleanupEClasses(stmt);
unfoldExpressions(stmt, ctx);
stmt.prepareForHashing();

// TODO: After this, stuff like CSE, A-A = 0, etc. must still be applied

Expand Down Expand Up @@ -1844,7 +1867,15 @@ private static void postCleanupIndexExpr(RewriterStatement cur) {

if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) {
// Then we just choose the first index
cur.forEachPreOrder(cur2 -> {
cur.getChild(1).forEachPreOrder(cur2 -> {
for (int i = 0; i < cur2.getOperands().size(); i++) {
if (cur2.getChild(i).equals(idx2))
cur2.getOperands().set(i, idx1);
}

return true;
}, true);
cur.getChild(2).forEachPreOrder(cur2 -> {
for (int i = 0; i < cur2.getOperands().size(); i++) {
if (cur2.getChild(i).equals(idx2))
cur2.getOperands().set(i, idx1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ public void testExactMatch() {
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));

assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
assert !stmt2.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2));
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
assert stmt2.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2));
}

@Test
Expand Down

0 comments on commit 78a5561

Please # to comment.