Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 6e037e2 commit 9c523e1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
51 changes: 49 additions & 2 deletions src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St

StringBuilder sb = new StringBuilder();

for (RewriterStatement var : vars) {
sb.append(generateDMLVariables(vars));
/*for (RewriterStatement var : vars) {
switch (var.getResultingDataType(ctx)) {
case "MATRIX":
sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
Expand All @@ -113,7 +114,7 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
default:
throw new NotImplementedException(var.getResultingDataType(ctx));
}
}
}*/

sb.append('\n');
sb.append("R1 = ");
Expand All @@ -131,6 +132,41 @@ public static String generateRuleValidationDML(RewriterRule rule, double eps, St
return sb.toString();
}

public static String generateDMLVariables(RewriterStatement root) {
Set<RewriterStatement> vars = new HashSet<>();
root.forEachPostOrder((stmt, pred) -> {
if (!stmt.isInstruction() && !stmt.isLiteral())
vars.add(stmt);
}, false);

return generateDMLVariables(vars);
}

public static String generateDMLVariables(Set<RewriterStatement> vars) {
StringBuilder sb = new StringBuilder();

for (RewriterStatement var : vars) {
switch (var.getResultingDataType(ctx)) {
case "MATRIX":
sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
break;
case "FLOAT":
sb.append(var.getId() + " = as.scalar(rand())\n");
break;
case "INT":
sb.append(var.getId() + " = as.integer(as.scalar(rand(min=0.0, max=10000.0)))\n");
break;
case "BOOL":
sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n");
break;
default:
throw new NotImplementedException(var.getResultingDataType(ctx));
}
}

return sb.toString();
}

public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) {
switch (dataType) {
case "MATRIX":
Expand All @@ -145,6 +181,17 @@ public static String generateEqualityCheck(String stmt1Var, String stmt2Var, Str
throw new NotImplementedException();
}

public static String generateDMLDefs(RewriterStatement stmt) {
Map<String, RewriterStatement> vars = new HashMap<>();

stmt.forEachPostOrder((cur, pred) -> {
if (!cur.isInstruction() && !cur.isLiteral())
vars.put(cur.getId(), cur);
}, false);

return generateDMLDefs(vars);
}

public static String generateDMLDefs(Map<String, RewriterStatement> defs) {
StringBuilder sb = new StringBuilder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,11 @@ public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final R
MutableBoolean isValid = new MutableBoolean(false);
DMLExecutor.executeCode(code, DMLCodeGenerator.ruleValidationScript(sessionId, isValid::setValue));

String code2 = DMLCodeGenerator.generateDML(rule.getStmt1());
String code2Header = DMLCodeGenerator.generateDMLVariables(rule.getStmt1());
String code2 = code2Header + "\nresult = " + DMLCodeGenerator.generateDML(rule.getStmt1()) + "\nprint(lineage(result))";
RewriterRuntimeUtils.attachHopInterceptor(prog -> {
DMLExecutor.println("HERE");
DMLExecutor.println(prog.getStatementBlocks().get(0).getHops().get(0).getInput(0));
DMLExecutor.println(prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0));
List<RewriterStatement> topLevelStmts = RewriterRuntimeUtils.getTopLevelHops(prog, ctx);
DMLExecutor.println(topLevelStmts);
// TODO: Evaluate cost and if our rule can still be applied
Expand Down

0 comments on commit 9c523e1

Please # to comment.