forked from apache/systemds
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
320 additions
and
0 deletions.
There are no files selected for viewing
196 changes: 196 additions & 0 deletions
196
src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
package org.apache.sysds.hops.rewriter; | ||
|
||
import org.apache.commons.lang3.NotImplementedException; | ||
|
||
import java.util.HashMap; | ||
import java.util.HashSet; | ||
import java.util.Map; | ||
import java.util.Set; | ||
import java.util.function.BiFunction; | ||
import java.util.function.Function; | ||
|
||
public class DMLCodeGenerator { | ||
public static final double EPS = 1e-10; | ||
|
||
|
||
private static final HashSet<String> printAsBinary = new HashSet<>(); | ||
private static final HashMap<String, BiFunction<RewriterStatement, StringBuilder, Boolean>> customEncoders = new HashMap<>(); | ||
private static final RuleContext ctx = RewriterUtils.buildDefaultContext(); | ||
|
||
static { | ||
printAsBinary.add("+"); | ||
printAsBinary.add("-"); | ||
printAsBinary.add("*"); | ||
printAsBinary.add("/"); | ||
printAsBinary.add("^"); | ||
printAsBinary.add("=="); | ||
printAsBinary.add("!="); | ||
printAsBinary.add(">"); | ||
printAsBinary.add(">="); | ||
printAsBinary.add("<"); | ||
printAsBinary.add("<="); | ||
|
||
customEncoders.put("[]", (stmt, sb) -> { | ||
if (stmt.getOperands().size() == 3) { | ||
sb.append('('); | ||
appendExpression(stmt.getChild(0), sb); | ||
sb.append(")["); | ||
appendExpression(stmt.getChild(1), sb); | ||
sb.append(", "); | ||
appendExpression(stmt.getChild(2), sb); | ||
sb.append(']'); | ||
return true; | ||
} else if (stmt.getOperands().size() == 5) { | ||
sb.append('('); | ||
appendExpression(stmt.getChild(0), sb); | ||
sb.append(")["); | ||
appendExpression(stmt.getChild(1), sb); | ||
sb.append(" : "); | ||
appendExpression(stmt.getChild(2), sb); | ||
sb.append(", "); | ||
appendExpression(stmt.getChild(3), sb); | ||
sb.append(" : "); | ||
appendExpression(stmt.getChild(4), sb); | ||
sb.append(']'); | ||
return true; | ||
} | ||
|
||
return false; | ||
}); | ||
} | ||
|
||
public static String generateRuleValidationDML(RewriterRule rule, double eps, String sessionId) { | ||
RewriterStatement stmtFrom = rule.getStmt1(); | ||
RewriterStatement stmtTo = rule.getStmt2(); | ||
|
||
Set<RewriterStatement> vars = new HashSet<>(); | ||
|
||
stmtFrom.forEachPostOrder((stmt, pred) -> { | ||
if (!stmt.isInstruction() && !stmt.isLiteral()) | ||
vars.add(stmt); | ||
}, false); | ||
|
||
stmtTo.forEachPostOrder((stmt, pred) -> { | ||
if (!stmt.isInstruction() && !stmt.isLiteral()) | ||
vars.add(stmt); | ||
}, false); | ||
|
||
StringBuilder sb = new StringBuilder(); | ||
|
||
for (RewriterStatement var : vars) { | ||
switch (var.getResultingDataType(ctx)) { | ||
case "MATRIX": | ||
sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n"); | ||
break; | ||
case "FLOAT": | ||
sb.append(var.getId() + " = as.scalar(rand())\n"); | ||
break; | ||
case "INT": | ||
sb.append(var.getId() + " = as.integer(as.scalar(rand(min=0.0, max=10000.0)))\n"); | ||
break; | ||
case "BOOL": | ||
sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n"); | ||
break; | ||
default: | ||
throw new NotImplementedException(var.getResultingDataType(ctx)); | ||
} | ||
} | ||
|
||
sb.append('\n'); | ||
sb.append("R1 = "); | ||
sb.append(generateDML(stmtFrom)); | ||
sb.append('\n'); | ||
sb.append("R2 = "); | ||
sb.append(generateDML(stmtTo)); | ||
sb.append('\n'); | ||
sb.append("print(\""); | ||
sb.append(sessionId); | ||
sb.append(" valid: \" + ("); | ||
sb.append(generateEqualityCheck("R1", "R2", stmtFrom.getResultingDataType(ctx), eps)); | ||
sb.append("))"); | ||
|
||
return sb.toString(); | ||
} | ||
|
||
public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) { | ||
switch (dataType) { | ||
case "MATRIX": | ||
return "sum(abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps + ") == length(" + stmt1Var + ")"; | ||
case "INT": | ||
case "BOOL": | ||
return stmt1Var + " == " + stmt2Var; | ||
case "FLOAT": | ||
return "abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps; | ||
} | ||
|
||
throw new NotImplementedException(); | ||
} | ||
|
||
public static String generateDMLDefs(Map<String, RewriterStatement> defs) { | ||
StringBuilder sb = new StringBuilder(); | ||
|
||
defs.forEach((k, v) -> { | ||
sb.append(k); | ||
sb.append(" = "); | ||
sb.append(generateDML(v)); | ||
sb.append('\n'); | ||
}); | ||
|
||
return sb.toString(); | ||
} | ||
|
||
public static String generateDML(RewriterStatement root) { | ||
StringBuilder sb = new StringBuilder(); | ||
appendExpression(root, sb); | ||
|
||
return sb.toString(); | ||
} | ||
|
||
private static void appendExpression(RewriterStatement cur, StringBuilder sb) { | ||
if (cur.isInstruction()) { | ||
resolveExpression((RewriterInstruction) cur, sb); | ||
} else { | ||
if (cur.isLiteral()) | ||
sb.append(cur.getLiteral()); | ||
else | ||
sb.append(cur.getId()); | ||
} | ||
} | ||
|
||
private static void resolveExpression(RewriterInstruction expr, StringBuilder sb) { | ||
String typedInstr = expr.trueTypedInstruction(ctx); | ||
String unTypedInstr = expr.trueInstruction(); | ||
|
||
if (expr.getOperands().size() == 2 && (printAsBinary.contains(typedInstr) || printAsBinary.contains(unTypedInstr))) { | ||
sb.append('('); | ||
appendExpression(expr.getChild(0), sb); | ||
sb.append(") "); | ||
sb.append(unTypedInstr); | ||
sb.append(" ("); | ||
appendExpression(expr.getChild(1), sb); | ||
sb.append(')'); | ||
return; | ||
} | ||
|
||
BiFunction<RewriterStatement, StringBuilder, Boolean> customEncoder = customEncoders.get(typedInstr); | ||
|
||
if (customEncoder == null) | ||
customEncoder = customEncoders.get(unTypedInstr); | ||
|
||
if (customEncoder == null) { | ||
sb.append(unTypedInstr); | ||
sb.append('('); | ||
|
||
for (int i = 0; i < expr.getOperands().size(); i++) { | ||
if (i != 0) | ||
sb.append(", "); | ||
|
||
appendExpression(expr.getChild(i), sb); | ||
} | ||
|
||
sb.append(')'); | ||
} else { | ||
customEncoder.apply(expr, sb); | ||
} | ||
} | ||
} |
63 changes: 63 additions & 0 deletions
63
src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
package org.apache.sysds.hops.rewriter; | ||
|
||
import org.apache.sysds.api.DMLScript; | ||
|
||
import java.io.OutputStream; | ||
import java.io.PrintStream; | ||
import java.util.function.Consumer; | ||
|
||
public class DMLExecutor { | ||
private static PrintStream origPrintStream = System.out; | ||
|
||
// This cannot run in parallel | ||
public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor) { | ||
try { | ||
if (consoleInterceptor != null) | ||
System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor))); | ||
|
||
DMLScript.executeScript(new String[]{"-s", code}); | ||
|
||
} catch (Exception e) { | ||
e.printStackTrace(); | ||
} | ||
|
||
if (consoleInterceptor != null) | ||
System.setOut(origPrintStream); | ||
} | ||
|
||
// Bypasses the interceptor | ||
public static void println(Object o) { | ||
origPrintStream.println(o); | ||
} | ||
|
||
private static class CustomOutputStream extends OutputStream { | ||
private PrintStream ps; | ||
private StringBuilder buffer = new StringBuilder(); | ||
private Consumer<String> lineHandler; | ||
|
||
public CustomOutputStream(PrintStream actualPrintStream, Consumer<String> lineHandler) { | ||
this.ps = actualPrintStream; | ||
this.lineHandler = lineHandler; | ||
} | ||
|
||
@Override | ||
public void write(int b) { | ||
char c = (char) b; | ||
if (c == '\n') { | ||
lineHandler.accept(buffer.toString()); | ||
buffer.setLength(0); // Clear the buffer after handling the line | ||
} else { | ||
buffer.append(c); // Accumulate characters until newline | ||
} | ||
// Handle the byte 'b', or you can write to any custom destination | ||
//ps.print((char) b); // Example: redirect to System.err | ||
} | ||
|
||
@Override | ||
public void write(byte[] b, int off, int len) { | ||
for (int i = off; i < off + len; i++) { | ||
write(b[i]); | ||
} | ||
} | ||
} | ||
} |
61 changes: 61 additions & 0 deletions
61
src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package org.apache.sysds.test.component.codegen.rewrite.functions; | ||
|
||
import org.apache.commons.lang3.mutable.MutableBoolean; | ||
import org.apache.sysds.hops.rewriter.DMLCodeGenerator; | ||
import org.apache.sysds.hops.rewriter.DMLExecutor; | ||
import org.apache.sysds.hops.rewriter.RewriterRule; | ||
import org.apache.sysds.hops.rewriter.RewriterRuleSet; | ||
import org.apache.sysds.hops.rewriter.RewriterStatement; | ||
import org.apache.sysds.hops.rewriter.RewriterUtils; | ||
import org.apache.sysds.hops.rewriter.RuleContext; | ||
import org.junit.BeforeClass; | ||
import org.junit.Test; | ||
|
||
import java.util.List; | ||
import java.util.UUID; | ||
import java.util.function.Function; | ||
|
||
public class DMLCodeGenTest { | ||
|
||
private static RuleContext ctx; | ||
private static Function<RewriterStatement, RewriterStatement> canonicalConverter; | ||
|
||
@BeforeClass | ||
public static void setup() { | ||
ctx = RewriterUtils.buildDefaultContext(); | ||
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); | ||
} | ||
|
||
@Test | ||
public void test1() { | ||
RewriterStatement stmt = RewriterUtils.parse("trace(+(A, t(B)))", ctx, "MATRIX:A,B"); | ||
System.out.println(DMLCodeGenerator.generateDML(stmt)); | ||
} | ||
|
||
@Test | ||
public void test2() { | ||
String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA"; | ||
String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))"; | ||
RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); | ||
RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); | ||
|
||
//RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2)); | ||
String sessionId = UUID.randomUUID().toString(); | ||
String validationScript = DMLCodeGenerator.generateRuleValidationDML(rule2, DMLCodeGenerator.EPS, sessionId); | ||
System.out.println("Validation script:"); | ||
System.out.println(validationScript); | ||
MutableBoolean valid = new MutableBoolean(true); | ||
DMLExecutor.executeCode(validationScript, line -> { | ||
if (!line.startsWith(sessionId)) | ||
return; | ||
|
||
if (!line.endsWith("valid: TRUE")) { | ||
DMLExecutor.println("An invalid rule was found!"); | ||
valid.setValue(false); | ||
} | ||
}); | ||
|
||
System.out.println("Exiting..."); | ||
assert valid.booleanValue(); | ||
} | ||
} |