Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 605d1bd commit b90a4a4
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1195,25 +1195,27 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
);
});

// ifelse expression pullup
rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:a,c,d")
.parseGlobalVars("BOOL:b")
.withParsedStatement("$1:ElementWiseInstruction(ifelse(b, a, c), d)", hooks)
.toParsedStatement("ifelse(b, $2:ElementWiseInstruction(a, d), $3:ElementWiseInstruction(c, d))", hooks)
.linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true)
.build()
);
rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:a,c,d")
.parseGlobalVars("BOOL:b")
.withParsedStatement("$1:ElementWiseInstruction(d, ifelse(b, a, c))", hooks)
.toParsedStatement("ifelse(b, $2:ElementWiseInstruction(d, a), $3:ElementWiseInstruction(d, c))", hooks)
.linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true)
.build()
);
RewriterUtils.buildBinaryPermutations(SCALARS, (t1, t2) -> {
// ifelse expression pullup
rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:a,c,d")
.parseGlobalVars("BOOL:b")
.withParsedStatement("$1:ElementWiseInstruction(ifelse(b, a, c), d)", hooks)
.toParsedStatement("ifelse(b, $2:ElementWiseInstruction(a, d), $3:ElementWiseInstruction(c, d))", hooks)
.linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true)
.build()
);
rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup")
.setUnidirectional(true)
.parseGlobalVars("FLOAT:a,c,d")
.parseGlobalVars("BOOL:b")
.withParsedStatement("$1:ElementWiseInstruction(d, ifelse(b, a, c))", hooks)
.toParsedStatement("ifelse(b, $2:ElementWiseInstruction(d, a), $3:ElementWiseInstruction(d, c))", hooks)
.linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true)
.build()
);
});

rules.add(new RewriterRuleBuilder(ctx, "Ifelse branch merge")
.setUnidirectional(true)
Expand All @@ -1224,6 +1226,8 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
.build()
);

// Pushdown successive

rules.add(new RewriterRuleBuilder(ctx)
.setUnidirectional(true)
.parseGlobalVars("INT:l")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,7 @@ public static Function<RewriterStatement, RewriterStatement> buildCanonicalFormC
}
RewriterUtils.mergeArgLists(stmt, ctx);
stmt = RewriterUtils.pullOutConstants(stmt, ctx);
cleanupUnecessaryIndexExpressions(stmt, ctx);
stmt.prepareForHashing();
stmt.recomputeHashCodes(ctx);

Expand Down Expand Up @@ -1442,7 +1443,7 @@ private static RewriterStatement pullOutConstantsRecursively(RewriterStatement c
}

private static RewriterStatement tryPullOutSum(RewriterStatement sum, final RuleContext ctx) {
// TODO: What happens on multi-index? Then, some unnecessary indices wil currently not be pulled out
// TODO: What happens on multi-index? Then, some unnecessary indices will currently not be pulled out
RewriterStatement idxExpr = sum.getChild(0);
UUID ownerId = (UUID) idxExpr.getMeta("ownerId");
RewriterStatement sumBody = idxExpr.getChild(1);
Expand Down Expand Up @@ -1706,6 +1707,47 @@ private static RewriterStatement foldUnary(RewriterStatement stmt, final RuleCon
return stmt;
}

public static void cleanupUnecessaryIndexExpressions(RewriterStatement stmt, final RuleContext ctx) {
stmt.forEachPostOrder((cur, pred) -> {
if (!cur.isInstruction() || !cur.trueInstruction().equals("sum"))
return;

cur = cur.getChild(0);

if (!cur.isInstruction() || !cur.trueInstruction().equals("_idxExpr"))
return;

if (!cur.getChild(1).isInstruction() || !cur.getChild(1).trueInstruction().equals("ifelse") || !cur.getChild(1,2).isLiteral() || cur.getChild(1,2).floatLiteral() != 0.0D)
return;

RewriterStatement query = cur.getChild(1, 0);

if (query.isInstruction() && query.trueInstruction().equals("==")) {
RewriterStatement idx1 = query.getChild(0);
RewriterStatement idx2 = query.getChild(1);

if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) {
List<RewriterStatement> indices = cur.getChild(0).getOperands();
if (indices.contains(idx1)) {
boolean removed = indices.remove(idx2);

if (removed) {
cur.getOperands().set(1, cur.getChild(1, 1));
cur.getChild(1).forEachPreOrder(cur2 -> {
for (int i = 0; i < cur2.getOperands().size(); i++) {
if (cur2.getChild(i).equals(idx2))
cur2.getOperands().set(i, idx1);
}

return true;
}, true);
}
}
}
}
}, false);
}

public static RewriterStatement doCSE(RewriterStatement stmt, final RuleContext ctx) {
throw new NotImplementedException();
}
Expand Down

0 comments on commit b90a4a4

Please # to comment.