Skip to content

Commit

Permalink
Update RewriterRuntimeUtils.java
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent da99067 commit 26276d8
Showing 1 changed file with 55 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -106,16 +114,56 @@ public static RewriterStatement buildDAGFromHop(Hop hop, int maxDepth, final Rul
}

public static void forAllUniqueTranslatableStatements(DMLProgram program, int maxDepth, Consumer<RewriterStatement> stmt, RewriterDatabase db, final RuleContext ctx) {
for (StatementBlock sb : program.getStatementBlocks()) {
if (sb.getHops() != null)
sb.getHops().forEach(hop -> forAllUniqueTranslatableStatements(hop, maxDepth, stmt, new HashSet<>(), db, ctx));
Set<Hop> visited = new HashSet<>();

if (sb.getStatements() != null) {
for (Statement s : sb.getStatements()) {
// TODO: Handle
}
for (String namespaceKey : program.getNamespaces().keySet()) {
for (String fname : program.getFunctionStatementBlocks(namespaceKey).keySet()) {
FunctionStatementBlock fsblock = program.getFunctionStatementBlock(namespaceKey, fname);
handleStatementBlock(fsblock, maxDepth, stmt, visited, db, ctx);
}
}

for (StatementBlock sb : program.getStatementBlocks()) {
handleStatementBlock(sb, maxDepth, stmt, visited, db, ctx);
}
}

private static void handleStatementBlock(StatementBlock sb, int maxDepth, Consumer<RewriterStatement> consumer, Set<Hop> visited, RewriterDatabase db, final RuleContext ctx) {
if (sb instanceof FunctionStatementBlock)
{
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
fstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx));
}
else if (sb instanceof WhileStatementBlock)
{
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
forAllUniqueTranslatableStatements(wsb.getPredicateHops(), maxDepth, consumer, visited, db, ctx);
wstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx));
}
else if (sb instanceof IfStatementBlock)
{
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement)isb.getStatement(0);
forAllUniqueTranslatableStatements(isb.getPredicateHops(), maxDepth, consumer, visited, db, ctx);
istmt.getIfBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx));
istmt.getElseBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx));
}
else if (sb instanceof ForStatementBlock)
{
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
forAllUniqueTranslatableStatements(fsb.getFromHops(), maxDepth, consumer, visited, db, ctx);
forAllUniqueTranslatableStatements(fsb.getToHops(), maxDepth, consumer, visited, db, ctx);
forAllUniqueTranslatableStatements(fsb.getIncrementHops(), maxDepth, consumer, visited, db, ctx);
fstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx));
}
else
{
if (sb.getHops() != null)
sb.getHops().forEach(hop -> forAllUniqueTranslatableStatements(hop, maxDepth, consumer, visited, db, ctx));
}
}

private static void forAllUniqueTranslatableStatements(Hop currentHop, int maxDepth, Consumer<RewriterStatement> consumer, Set<Hop> visited, RewriterDatabase db, final RuleContext ctx) {
Expand Down

0 comments on commit 26276d8

Please # to comment.