Skip to content

Commit

Permalink
Validation script implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent e153ae1 commit 8b866c6
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 0 deletions.
196 changes: 196 additions & 0 deletions src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package org.apache.sysds.hops.rewriter;

import org.apache.commons.lang3.NotImplementedException;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * Turns {@link RewriterStatement} expression trees into DML source text and
 * builds small DML scripts that numerically compare both sides of a
 * {@link RewriterRule}.
 */
public class DMLCodeGenerator {
    /** Default absolute tolerance offered to callers of the validation-script generator. */
    public static final double EPS = 1e-10;

    /** Instructions that are printed in infix form when they have exactly two operands. */
    private static final HashSet<String> printAsBinary = new HashSet<>();

    /**
     * Instruction-specific encoders, looked up by typed and then by untyped instruction name.
     * An encoder returns {@code true} when it has written the expression and
     * {@code false} when it does not handle the given operand layout.
     */
    private static final HashMap<String, BiFunction<RewriterStatement, StringBuilder, Boolean>> customEncoders = new HashMap<>();

    private static final RuleContext ctx = RewriterUtils.buildDefaultContext();

    static {
        printAsBinary.add("+");
        printAsBinary.add("-");
        printAsBinary.add("*");
        printAsBinary.add("/");
        printAsBinary.add("^");
        printAsBinary.add("==");
        printAsBinary.add("!=");
        printAsBinary.add(">");
        printAsBinary.add(">=");
        printAsBinary.add("<");
        printAsBinary.add("<=");

        // Indexing: three operands -> "(op0)[op1, op2]",
        // five operands -> "(op0)[op1 : op2, op3 : op4]".
        customEncoders.put("[]", (stmt, sb) -> {
            int numOperands = stmt.getOperands().size();

            if (numOperands != 3 && numOperands != 5)
                return false; // unsupported layout: nothing written, caller falls back

            sb.append('(');
            appendExpression(stmt.getChild(0), sb);
            sb.append(")[");
            appendExpression(stmt.getChild(1), sb);

            if (numOperands == 3) {
                sb.append(", ");
                appendExpression(stmt.getChild(2), sb);
            } else {
                sb.append(" : ");
                appendExpression(stmt.getChild(2), sb);
                sb.append(", ");
                appendExpression(stmt.getChild(3), sb);
                sb.append(" : ");
                appendExpression(stmt.getChild(4), sb);
            }

            sb.append(']');
            return true;
        });
    }

    /**
     * Builds a DML script that checks a rewrite rule on random inputs.
     * <p>
     * The script first assigns a random value to every statement of either rule
     * side that is neither an instruction nor a literal (chosen by its resulting
     * data type), then assigns both rule sides to {@code R1} and {@code R2} and
     * finally prints {@code "<sessionId> valid: "} followed by the outcome of the
     * equality check. The equality check is selected by the data type of the
     * rule's first statement.
     *
     * @param rule the rule whose two statements are compared
     * @param eps absolute tolerance used for FLOAT and MATRIX comparisons
     * @param sessionId prefix of the printed result line, used by callers to find it in the output
     * @return the generated DML script
     * @throws NotImplementedException if a variable or the result has an unsupported data type
     */
    public static String generateRuleValidationDML(RewriterRule rule, double eps, String sessionId) {
        RewriterStatement stmtFrom = rule.getStmt1();
        RewriterStatement stmtTo = rule.getStmt2();

        Set<RewriterStatement> vars = new HashSet<>();
        collectVariables(stmtFrom, vars);
        collectVariables(stmtTo, vars);

        StringBuilder sb = new StringBuilder();

        for (RewriterStatement var : vars) {
            String dataType = var.getResultingDataType(ctx);

            switch (dataType) {
                case "MATRIX":
                    sb.append(var.getId()).append(" = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
                    break;
                case "FLOAT":
                    sb.append(var.getId()).append(" = as.scalar(rand())\n");
                    break;
                case "INT":
                    sb.append(var.getId()).append(" = as.integer(as.scalar(rand(min=0.0, max=10000.0)))\n");
                    break;
                case "BOOL":
                    sb.append(var.getId()).append(" = as.scalar(rand()) < 0.5\n");
                    break;
                default:
                    throw new NotImplementedException(dataType);
            }
        }

        sb.append('\n');
        sb.append("R1 = ");
        sb.append(generateDML(stmtFrom));
        sb.append('\n');
        sb.append("R2 = ");
        sb.append(generateDML(stmtTo));
        sb.append('\n');
        sb.append("print(\"");
        sb.append(sessionId);
        sb.append(" valid: \" + (");
        sb.append(generateEqualityCheck("R1", "R2", stmtFrom.getResultingDataType(ctx), eps));
        sb.append("))");

        return sb.toString();
    }

    /**
     * Builds a DML boolean expression that compares two script variables.
     * <ul>
     * <li>MATRIX: every cell must differ by less than {@code eps}</li>
     * <li>FLOAT: the absolute difference must be less than {@code eps}</li>
     * <li>INT, BOOL: exact equality</li>
     * </ul>
     *
     * @param stmt1Var name of the first script variable
     * @param stmt2Var name of the second script variable
     * @param dataType one of {@code "MATRIX"}, {@code "FLOAT"}, {@code "INT"}, {@code "BOOL"}
     * @param eps absolute tolerance for the MATRIX and FLOAT comparisons
     * @return the DML expression text
     * @throws NotImplementedException if {@code dataType} is none of the supported types
     */
    public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) {
        switch (dataType) {
            case "MATRIX":
                return "sum(abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps + ") == length(" + stmt1Var + ")";
            case "INT":
            case "BOOL":
                return stmt1Var + " == " + stmt2Var;
            case "FLOAT":
                return "abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps;
            default:
                // Name the offending type so the failure can be diagnosed.
                throw new NotImplementedException(dataType);
        }
    }

    /**
     * Emits one {@code name = expression} line per map entry, in the map's iteration order.
     *
     * @param defs variable names mapped to the statements that define them
     * @return the generated DML assignments, each terminated by a newline
     */
    public static String generateDMLDefs(Map<String, RewriterStatement> defs) {
        StringBuilder sb = new StringBuilder();

        defs.forEach((k, v) -> {
            sb.append(k);
            sb.append(" = ");
            sb.append(generateDML(v));
            sb.append('\n');
        });

        return sb.toString();
    }

    /**
     * Renders a statement tree as a single DML expression.
     *
     * @param root root of the statement tree
     * @return the DML expression text
     */
    public static String generateDML(RewriterStatement root) {
        StringBuilder sb = new StringBuilder();
        appendExpression(root, sb);

        return sb.toString();
    }

    /** Adds every statement below {@code root} that is neither an instruction nor a literal to {@code vars}. */
    private static void collectVariables(RewriterStatement root, Set<RewriterStatement> vars) {
        root.forEachPostOrder((stmt, pred) -> {
            if (!stmt.isInstruction() && !stmt.isLiteral())
                vars.add(stmt);
        }, false);
    }

    /** Appends an instruction via {@link #resolveExpression}, a literal by its value, anything else by its id. */
    private static void appendExpression(RewriterStatement cur, StringBuilder sb) {
        if (cur.isInstruction()) {
            resolveExpression((RewriterInstruction) cur, sb);
        } else if (cur.isLiteral()) {
            sb.append(cur.getLiteral());
        } else {
            sb.append(cur.getId());
        }
    }

    /**
     * Appends the DML text of an instruction. Two-operand instructions listed in
     * {@link #printAsBinary} are written in parenthesised infix form; otherwise a
     * custom encoder is tried and, if none applies, the generic
     * {@code name(arg0, arg1, ...)} form is used.
     */
    private static void resolveExpression(RewriterInstruction expr, StringBuilder sb) {
        String typedInstr = expr.trueTypedInstruction(ctx);
        String unTypedInstr = expr.trueInstruction();

        if (expr.getOperands().size() == 2 && (printAsBinary.contains(typedInstr) || printAsBinary.contains(unTypedInstr))) {
            sb.append('(');
            appendExpression(expr.getChild(0), sb);
            sb.append(") ");
            sb.append(unTypedInstr);
            sb.append(" (");
            appendExpression(expr.getChild(1), sb);
            sb.append(')');
            return;
        }

        BiFunction<RewriterStatement, StringBuilder, Boolean> customEncoder = customEncoders.get(typedInstr);

        if (customEncoder == null)
            customEncoder = customEncoders.get(unTypedInstr);

        if (customEncoder != null) {
            int mark = sb.length();

            if (Boolean.TRUE.equals(customEncoder.apply(expr, sb)))
                return;

            // The encoder declined this operand layout. Previously its result was
            // ignored and the expression silently vanished from the output; roll
            // back anything it may have written and use the generic encoding.
            sb.setLength(mark);
        }

        sb.append(unTypedInstr);
        sb.append('(');

        for (int i = 0; i < expr.getOperands().size(); i++) {
            if (i != 0)
                sb.append(", ");

            appendExpression(expr.getChild(i), sb);
        }

        sb.append(')');
    }
}
63 changes: 63 additions & 0 deletions src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package org.apache.sysds.hops.rewriter;

import org.apache.sysds.api.DMLScript;

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.function.Consumer;

/**
 * Runs DML code through {@link DMLScript} and optionally hands every line the
 * script writes to {@code System.out} to a callback.
 */
public class DMLExecutor {
    /** {@code System.out} as it was when this class was loaded; used by {@link #println(Object)}. */
    private static final PrintStream origPrintStream = System.out;

    /**
     * Executes the given DML code. While the script runs, {@code System.out} is
     * replaced (if an interceptor is given) by a stream that forwards each
     * completed line to {@code consoleInterceptor} instead of printing it.
     * <p>
     * Because {@code System.out} is process-wide, this cannot run in parallel;
     * the method is synchronized on the class. Exceptions thrown by the script
     * execution are printed via {@link Exception#printStackTrace()} and not rethrown.
     *
     * @param code the DML source to execute
     * @param consoleInterceptor receives each output line without its line terminator;
     *        may be {@code null} to leave {@code System.out} untouched
     */
    public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor) {
        // Remember the stream that is active right now so that nested calls and
        // externally redirected output are restored to what they were before.
        PrintStream previousOut = System.out;

        try {
            if (consoleInterceptor != null)
                System.setOut(new PrintStream(new CustomOutputStream(consoleInterceptor)));

            DMLScript.executeScript(new String[]{"-s", code});

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Restore in finally: an Error escaping the script (which the catch
            // above does not handle) must not leave System.out redirected.
            if (consoleInterceptor != null)
                System.setOut(previousOut);
        }
    }

    /**
     * Prints to the stream that was {@code System.out} when this class was loaded,
     * bypassing any interceptor installed by {@link #executeCode}.
     *
     * @param o the object to print
     */
    public static void println(Object o) {
        origPrintStream.println(o);
    }

    /**
     * Collects written bytes and passes each newline-terminated line, decoded as
     * text, to a handler. Nothing is forwarded to the real console. Output after
     * the last newline stays buffered and is not delivered.
     */
    private static class CustomOutputStream extends OutputStream {
        // Raw bytes of the current line. Bytes are decoded only once the line is
        // complete so that multi-byte characters are not torn apart.
        private final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        private final Consumer<String> lineHandler;

        CustomOutputStream(Consumer<String> lineHandler) {
            this.lineHandler = lineHandler;
        }

        @Override
        public void write(int b) {
            // Only the low eight bits are significant for OutputStream.write(int).
            if ((b & 0xFF) == '\n')
                emitLine();
            else
                buffer.write(b);
        }

        @Override
        public void write(byte[] b, int off, int len) {
            int end = off + len;
            int lineStart = off;

            for (int i = off; i < end; i++) {
                if (b[i] == '\n') {
                    buffer.write(b, lineStart, i - lineStart);
                    emitLine();
                    lineStart = i + 1;
                }
            }

            // Keep the unterminated remainder for the next write.
            buffer.write(b, lineStart, end - lineStart);
        }

        /** Decodes the buffered bytes, clears the buffer and hands the line to the handler. */
        private void emitLine() {
            // toString() decodes with the default charset, which is also what the
            // wrapping PrintStream(OutputStream) uses to encode.
            // NOTE(review): splitting on byte 0x0A assumes an ASCII-compatible default charset.
            String line = buffer.toString();
            buffer.reset();

            // Drop the carriage return of a CRLF line terminator.
            if (line.endsWith("\r"))
                line = line.substring(0, line.length() - 1);

            lineHandler.accept(line);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.apache.sysds.test.component.codegen.rewrite.functions;

import org.apache.commons.lang3.mutable.MutableBoolean;
import org.apache.sysds.hops.rewriter.DMLCodeGenerator;
import org.apache.sysds.hops.rewriter.DMLExecutor;
import org.apache.sysds.hops.rewriter.RewriterRule;
import org.apache.sysds.hops.rewriter.RewriterRuleSet;
import org.apache.sysds.hops.rewriter.RewriterStatement;
import org.apache.sysds.hops.rewriter.RewriterUtils;
import org.apache.sysds.hops.rewriter.RuleContext;
import org.junit.BeforeClass;
import org.junit.Test;

import java.util.List;
import java.util.UUID;
import java.util.function.Function;

public class DMLCodeGenTest {

private static RuleContext ctx;
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;

@BeforeClass
public static void setup() {
ctx = RewriterUtils.buildDefaultContext();
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
}

@Test
public void test1() {
RewriterStatement stmt = RewriterUtils.parse("trace(+(A, t(B)))", ctx, "MATRIX:A,B");
System.out.println(DMLCodeGenerator.generateDML(stmt));
}

@Test
public void test2() {
String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA";
String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))";
RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx);
RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx);

//RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2));
String sessionId = UUID.randomUUID().toString();
String validationScript = DMLCodeGenerator.generateRuleValidationDML(rule2, DMLCodeGenerator.EPS, sessionId);
System.out.println("Validation script:");
System.out.println(validationScript);
MutableBoolean valid = new MutableBoolean(true);
DMLExecutor.executeCode(validationScript, line -> {
if (!line.startsWith(sessionId))
return;

if (!line.endsWith("valid: TRUE")) {
DMLExecutor.println("An invalid rule was found!");
valid.setValue(false);
}
});

System.out.println("Exiting...");
assert valid.booleanValue();
}
}

0 comments on commit 8b866c6

Please sign in to comment.