Skip to content

Commit

Permalink
Update RewriterCodeGen.java
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 0961998 commit 8514192
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions src/main/java/org/apache/sysds/hops/rewriter/RewriterCodeGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import scala.Tuple2;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

public class RewriterCodeGen {
Expand Down Expand Up @@ -85,35 +87,64 @@ private static String generateRewriteFunction(RewriterRule rule, String fName, i

private static void buildMatchingSequence(RewriterStatement from, RewriterStatement to, StringBuilder sb, final RuleContext ctx, int indentation) {
Map<RewriterStatement, String> vars = new HashMap<>();
vars.put(from, "hi");
recursivelyBuildMatchingSequence(from, sb, "hi", ctx, indentation, vars);
sb.append("\n");
indent(indentation, sb);
sb.append("// Now we start building the new Hop\n");
sb.append("// Now, we start building the new Hop\n");

if (DEBUG) {
indent(indentation, sb);
sb.append("System.out.println(\"HERE\");\n");
}

buildRewrite(to, sb, vars, ctx, indentation);
Set<RewriterStatement> activeStatements = buildRewrite(to, sb, vars, ctx, indentation);

sb.append('\n');
indent(indentation, sb);
sb.append("return hi;\n");
}
sb.append("// Remove old unreferenced Hops\n");
removeUnreferencedHops(from, activeStatements, sb, vars, ctx, indentation);

private static void buildRewrite(RewriterStatement newRoot, StringBuilder sb, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation) {
recursivelyBuildNewHop(sb, newRoot, vars, ctx, indentation, 1);
sb.append('\n');
indent(indentation, sb);
sb.append("hi = " + vars.get(newRoot) + ";\n");
sb.append("ArrayList<Hop> parents = new ArrayList<>(hi.getParent());\n\n");
indent(indentation, sb);
sb.append("for ( Hop p : parents )\n");
indent(indentation + 1, sb);
sb.append("HopRewriteUtils.replaceChildReference(p, hi, " + vars.get(to) + ");\n\n");

indent(indentation, sb);
sb.append("return " + vars.get(to) + ";\n");
}

// Returns the set of all active statements after the rewrite
private static Set<RewriterStatement> buildRewrite(RewriterStatement newRoot, StringBuilder sb, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation) {
Set<RewriterStatement> visited = new HashSet<>();
recursivelyBuildNewHop(sb, newRoot, vars, ctx, indentation, 1, visited);
//indent(indentation, sb);
//sb.append("hi = " + vars.get(newRoot) + ";\n");

return visited;
}

private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation, int varCtr) {
private static void removeUnreferencedHops(RewriterStatement oldRoot, Set<RewriterStatement> activeStatements, StringBuilder sb, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation) {
oldRoot.forEachPreOrder(cur -> {
if (activeStatements.contains(cur))
return true;

indent(indentation, sb);
sb.append("HopRewriteUtils.removeAllChildReferences(" + vars.get(cur) + ");\n");
return true;
}, false);
}

private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation, int varCtr, Set<RewriterStatement> visited) {
visited.add(cur);
if (vars.containsKey(cur))
return varCtr;

for (RewriterStatement child : cur.getOperands()) {
varCtr = recursivelyBuildNewHop(sb, child, vars, ctx, indentation, varCtr);
}
for (RewriterStatement child : cur.getOperands())
varCtr = recursivelyBuildNewHop(sb, child, vars, ctx, indentation, varCtr, visited);

if (cur instanceof RewriterDataType) {
if (cur.isLiteral()) {
Expand Down

0 comments on commit 8514192

Please # to comment.