Skip to content

Commit

Permalink
Some bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent a32da49 commit 6411b5a
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 8 deletions.
15 changes: 11 additions & 4 deletions src/main/java/org/apache/sysds/hops/rewriter/DMLExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,25 @@
public class DMLExecutor {
private static PrintStream origPrintStream = System.out;

public static synchronized void executeCode(String code) {
executeCode(code, s -> {});
public static synchronized void executeCode(String code, boolean intercept, String... additionalArgs) {
executeCode(code, intercept ? s -> {} : null, additionalArgs);
}

// TODO: We will probably need some kind of watchdog
// This cannot run in parallel
public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor) {
public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor, String... additionalArgs) {
try {
if (consoleInterceptor != null)
System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor)));

DMLScript.executeScript(new String[]{"-s", code});
String[] args = new String[additionalArgs.length + 2];

for (int i = 0; i < additionalArgs.length; i++)
args[i] = additionalArgs[i];

args[additionalArgs.length] = "-s";
args[additionalArgs.length + 1] = code;
DMLScript.executeScript(args);

} catch (Exception e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.rewrite.HopRewriteRule;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;

import java.io.IOException;
Expand Down Expand Up @@ -74,6 +75,9 @@ private void rule_apply(Hop hop, boolean descendFirst)
if(hop.isVisited())
return;

//DMLExecutor.println("Hop: " + hop + ", " + hop.getName() + ": " + HopRewriteUtils.isSparse(hop));
//DMLExecutor.println("NNZ: " + hop.getNnz());

//System.out.println("Stepping into: " + hop);

//recursively process children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public static long estimateCost(RewriterStatement stmt, final RuleContext ctx) {
public static long estimateCost(RewriterStatement stmt, Function<RewriterStatement, Long> propertyGenerator, final RuleContext ctx) {
RewriterAssertions assertions = new RewriterAssertions(ctx);
RewriterStatement costFn = propagateCostFunction(stmt, ctx, assertions);
System.out.println(costFn.toParsableString(ctx));

Map<RewriterStatement, RewriterStatement> map = new HashMap<>();

Expand All @@ -42,6 +43,7 @@ public static long estimateCost(RewriterStatement stmt, Function<RewriterStateme
cur.getOperands().set(i, mNew);
} else if (op.isInstruction()) {
if (op.trueInstruction().equals("ncol") || op.trueInstruction().equals("nrow")) {
System.out.println("Generating for: " + op);
mNew = RewriterStatement.literal(ctx, propertyGenerator.apply(op));
map.put(op, mNew);
cur.getOperands().set(i, mNew);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,10 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
for (int idx = 0; idx < 2; idx++) {
RewriterStatement oldRef = lnk.oldStmt.getOperands().get(idx);
RewriterStatement newRef = lnk.newStmt.get(0).getChild(idx);
RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, newRef.getChild(1, 1, 0)).consolidate(ctx);
System.out.println("NewRef: " + newRef.toParsableString(ctx));
RewriterStatement mStmtC = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef.getChild(1, 1, 0), RewriterStatement.literal(ctx, -1L)).consolidate(ctx);
RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, mStmtC).consolidate(ctx);
System.out.println("mStmt: " + mStmt.toParsableString(ctx));
final RewriterStatement newStmt = RewriterUtils.foldConstants(mStmt, ctx);

// Replace all references to h with
Expand All @@ -1049,8 +1052,10 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
RewriterStatement child = el.getOperands().get(i);
Object meta = child.getMeta("idxId");

if (meta instanceof UUID && meta.equals(oldRef.getMeta("idxId")))
if (meta instanceof UUID && meta.equals(oldRef.getMeta("idxId"))) {
System.out.println("NewStmt: " + newStmt.toParsableString(ctx));
el.getOperands().set(i, newStmt);
}
}
}, false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final R
return false; // The program should not be executed as we just want to extract any rewrites that are applied to the current statement
});

DMLExecutor.executeCode(code2);
DMLExecutor.executeCode(code2, true);
RewriterRuntimeUtils.detachHopInterceptor();

return isValid.booleanValue() && isRelevant.booleanValue();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.apache.sysds.test.component.codegen.rewrite.functions;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.rewriter.DMLExecutor;
import org.junit.Test;

public class CodeTest {
@Test
public void test() {
String str = "X = rand(rows=5000, cols=5000, sparsity=0.1)\n" +
"Y = rand(rows=5000, cols=5000, sparsity=0.1)\n" +
"R = X*Y\n" +
"print(lineage(R))";
DMLScript.APPLY_GENERATED_REWRITES = true;
DMLExecutor.executeCode(str, false, "-applyGeneratedRewrites");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class CostEstimates {
@BeforeClass
public static void setup() {
ctx = RewriterUtils.buildDefaultContext();
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true);
}

@Test
Expand Down Expand Up @@ -217,4 +217,55 @@ public void test11() {
System.out.println(stmt2.toParsableString(ctx, true));
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}

@Test
public void test12() {
String stmtStr1 = "MATRIX:A,B\n" +
"LITERAL_INT:1\n" +
"+([](A, 1, nrow(A), 1, 1),B)";
String stmtStr2 = "MATRIX:A,B\n" +
"LITERAL_INT:1\n" +
"+([](A, 1, nrow(A), 1, ncol(A)), B)";

RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx);
RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx);

long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx);
long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx);
System.out.println("Cost1: " + cost1);
System.out.println("Cost2: " + cost2);
System.out.println("Ratio: " + ((double)cost1)/cost2);

assert cost1 < cost2;
}

@Test
public void test13() {
String stmtStr1 = "MATRIX:A,B\n" +
"LITERAL_INT:1\n" +
"[](rowSums(A), 1, nrow(A), 1, 1)";
String stmtStr2 = "MATRIX:A,B\n" +
"LITERAL_INT:1\n" +
"rowSums(A)";

RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx);
RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx);

long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx);
long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx);
System.out.println("Cost1: " + cost1);
System.out.println("Cost2: " + cost2);
System.out.println("Ratio: " + ((double)cost1)/cost2);

assert cost2 < cost1;

stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}
}

0 comments on commit 6411b5a

Please # to comment.