Skip to content

Commit

Permalink
Cost estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 7f0e47e commit e63f5e9
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,18 @@ public static String getDefaultContextString() {
builder.append("diag(MATRIX)::MATRIX\n");

List.of("INT", "FLOAT", "BOOL").forEach(t -> {
builder.append("sum(" + t + "...)::" + t + "\n");
builder.append("sum(" + t + "*)::" + t + "\n");
builder.append("sum(" + t + ")::" + t + "\n");
String newType = t.equals("BOOL") ? "INT" : t;
builder.append("sum(" + t + "...)::" + newType + "\n");
builder.append("sum(" + t + "*)::" + newType + "\n");
builder.append("sum(" + t + ")::" + newType + "\n");

builder.append("min(" + t + "...)::" + t + "\n");
builder.append("min(" + t + "*)::" + t + "\n");
builder.append("min(" + t + ")::" + t + "\n");

builder.append("max(" + t + "...)::" + t + "\n");
builder.append("max(" + t + "*)::" + t + "\n");
builder.append("max(" + t + ")::" + t + "\n");
});

builder.append("_m(INT,INT,FLOAT)::MATRIX\n");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.apache.sysds.hops.rewriter;

import org.apache.commons.lang3.mutable.MutableLong;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -8,13 +10,13 @@
import java.util.function.Function;

public class RewriterCostEstimator {
private static final long INSTRUCTION_OVERHEAD = 10;
private static final long MALLOC_COST = 10000;

public static long estimateCost(RewriterStatement stmt, Function<RewriterStatement, Long> propertyGenerator, final RuleContext ctx) {
RewriterAssertions assertions = new RewriterAssertions(ctx);
RewriterStatement costFn = propagateCostFunction(stmt, ctx, assertions);

// Now, assign
System.out.println(costFn);

Map<RewriterStatement, RewriterStatement> map = new HashMap<>();

costFn.forEachPostOrder((cur, parent, pIdx) -> {
Expand Down Expand Up @@ -47,14 +49,18 @@ public static long estimateCost(RewriterStatement stmt, Function<RewriterStateme

private static RewriterStatement propagateCostFunction(RewriterStatement stmt, final RuleContext ctx, RewriterAssertions assertions) {
List<RewriterStatement> includedCosts = new ArrayList<>();
MutableLong instructionOverhead = new MutableLong(0);

stmt.forEachPostOrder((cur, parent, pIdx) -> {
if (!(cur instanceof RewriterInstruction))
return;

computeCostOf((RewriterInstruction) cur, ctx, includedCosts, assertions);
computeCostOf((RewriterInstruction) cur, ctx, includedCosts, assertions, instructionOverhead);
instructionOverhead.add(INSTRUCTION_OVERHEAD);
});

includedCosts.add(RewriterStatement.literal(ctx, instructionOverhead.longValue()));

RewriterStatement argList = RewriterStatement.argList(ctx, includedCosts);
RewriterStatement add = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(argList).consolidate(ctx);
add.unsafePutMeta("_assertions", assertions);
Expand All @@ -63,14 +69,14 @@ private static RewriterStatement propagateCostFunction(RewriterStatement stmt, f
return add;
}

private static void computeCostOf(RewriterInstruction instr, final RuleContext ctx, List<RewriterStatement> uniqueCosts, RewriterAssertions assertions) {
private static void computeCostOf(RewriterInstruction instr, final RuleContext ctx, List<RewriterStatement> uniqueCosts, RewriterAssertions assertions, MutableLong instructionOverhead) {
if (instr.getResultingDataType(ctx).equals("MATRIX"))
computeMatrixOpCost(instr, ctx, uniqueCosts, assertions);
computeMatrixOpCost(instr, ctx, uniqueCosts, assertions, instructionOverhead);
else
computeScalarOpCost(instr, ctx, uniqueCosts, assertions);
computeScalarOpCost(instr, ctx, uniqueCosts, assertions, instructionOverhead);
}

private static void computeMatrixOpCost(RewriterInstruction instr, final RuleContext ctx, List<RewriterStatement> uniqueCosts, RewriterAssertions assertions) {
private static void computeMatrixOpCost(RewriterInstruction instr, final RuleContext ctx, List<RewriterStatement> uniqueCosts, RewriterAssertions assertions, MutableLong overhead) {
RewriterStatement cost = null;
Map<String, RewriterStatement> map = new HashMap<>();

Expand All @@ -85,12 +91,53 @@ private static void computeMatrixOpCost(RewriterInstruction instr, final RuleCon
// Rough estimation
cost = RewriterUtils.parse("*(argList(nrowA, ncolA, ncolB, +(argList(mulCost, sumCost))))", ctx, map);
assertions.addEqualityAssertion(map.get("ncolA"), map.get("nrowB"));
overhead.add(MALLOC_COST);
break;
case "t":
case "rowSums":
case "colSums":
map.put("nrowA", instr.getChild(0).getNRow());
map.put("ncolA", instr.getChild(0).getNCol());
// Rough estimation
cost = RewriterUtils.parse("*(argList(nrowA, ncolA))", ctx, map);
overhead.add(MALLOC_COST);
break;
case "diag":
map.put("nrowA", instr.getChild(0).getNRow());
map.put("ncolA", instr.getChild(0).getNCol());
cost = map.get("nrowA");
assertions.addEqualityAssertion(map.get("nrowA"), map.get("ncolA"));
overhead.add(MALLOC_COST);
break;
case "cast.MATRIX":
cost = RewriterStatement.literal(ctx, 5L);
break;
case "[]":
break; // I assume that nothing is materialized
case "RBind":
	// Cost model: every cell of both operands is touched once.
	map.put("nrowA", instr.getChild(0).getNRow());
	map.put("ncolA", instr.getChild(0).getNCol());
	// The second operand needs its own keys; reusing nrowA/ncolA would overwrite the first operand
	// and leave nrowB/ncolB (referenced by the expression below) undefined.
	map.put("nrowB", instr.getChild(1).getNRow());
	map.put("ncolB", instr.getChild(1).getNCol());
	// The expression has to be parsed against the variable map; a plain map lookup of the
	// expression string always yields null.
	cost = RewriterUtils.parse("+(argList(*(argList(nrowA, ncolA)), *(argList(nrowB, ncolB))))", ctx, map);
	assertions.addEqualityAssertion(instr.getChild(0).getNCol(), instr.getChild(1).getNCol());
	overhead.add(MALLOC_COST);
	break;
case "CBind":
	// Same cost model as RBind; only the dimension that must agree differs (rows instead of columns).
	map.put("nrowA", instr.getChild(0).getNRow());
	map.put("ncolA", instr.getChild(0).getNCol());
	map.put("nrowB", instr.getChild(1).getNRow());
	map.put("ncolB", instr.getChild(1).getNCol());
	cost = RewriterUtils.parse("+(argList(*(argList(nrowA, ncolA)), *(argList(nrowB, ncolB))))", ctx, map);
	assertions.addEqualityAssertion(instr.getChild(0).getNRow(), instr.getChild(1).getNRow());
	overhead.add(MALLOC_COST);
	break;
case "rand":
	// NOTE(review): the dimensions are read from the first operand, as before; if the operands of
	// rand are scalars, instr.getNRow()/instr.getNCol() is probably what is intended - confirm.
	map.put("nrowA", instr.getChild(0).getNRow());
	map.put("ncolA", instr.getChild(0).getNCol());
	// Parse the expression instead of looking it up as a map key (which always returned null).
	cost = RewriterUtils.parse("*(argList(nrowA, ncolA))", ctx, map);
	overhead.add(MALLOC_COST);
	break;
}

if (cost == null) {
Expand All @@ -101,6 +148,7 @@ private static void computeMatrixOpCost(RewriterInstruction instr, final RuleCon
.withOps(RewriterStatement.argList(ctx, opCost, instr.getNCol(), instr.getNRow()));
assertions.addEqualityAssertion(instr.getChild(0).getNCol(), instr.getChild(1).getNCol());
assertions.addEqualityAssertion(instr.getChild(0).getNRow(), instr.getChild(1).getNRow());
overhead.add(MALLOC_COST);
} else {
throw new IllegalArgumentException();
}
Expand All @@ -109,11 +157,27 @@ private static void computeMatrixOpCost(RewriterInstruction instr, final RuleCon
uniqueCosts.add(cost);
}

private static void computeScalarOpCost(RewriterInstruction instr, final RuleContext ctx, List<RewriterStatement> uniqueCosts, RewriterAssertions assertions) {
RewriterStatement cost = null;
long opCost = atomicOpCost(instr.trueInstruction());

/**
 * Appends the cost term of a scalar-valued instruction to {@code uniqueCosts}.
 * <ul>
 *   <li>{@code sum}/{@code min}/{@code max} over a matrix: nrow * ncol of the operand.</li>
 *   <li>{@code trace} of a matrix: nrow of the operand; additionally records the assertion
 *       that the operand's nrow equals its ncol.</li>
 *   <li>{@code [](MATRIX,INT,INT)}: no cost term is added.</li>
 *   <li>anything else: the integer literal returned by {@code atomicOpCost} for the instruction.</li>
 * </ul>
 *
 * @param instr       the instruction to estimate
 * @param ctx         rule context used for type resolution and parsing
 * @param uniqueCosts list that receives the cost term (if any)
 * @param assertions  collector for dimension-equality assertions
 * @param overhead    accumulated constant overhead; not modified by this method
 */
private static void computeScalarOpCost(RewriterInstruction instr, final RuleContext ctx, List<RewriterStatement> uniqueCosts, RewriterAssertions assertions, MutableLong overhead) {
	final String typedInstr = instr.trueTypedInstruction(ctx);

	switch (typedInstr) {
		case "[](MATRIX,INT,INT)":
			// Indexing a single cell contributes no cost term.
			return;
		case "trace(MATRIX)": {
			final RewriterStatement nrow = instr.getChild(0).getNRow();
			final RewriterStatement ncol = instr.getChild(0).getNCol();
			uniqueCosts.add(nrow);
			assertions.addEqualityAssertion(nrow, ncol);
			return;
		}
		case "sum(MATRIX)":
		case "min(MATRIX)":
		case "max(MATRIX)": {
			// Full aggregates are costed as nrow * ncol of the operand.
			final Map<String, RewriterStatement> dims = new HashMap<>();
			dims.put("nrowA", instr.getChild(0).getNRow());
			dims.put("ncolA", instr.getChild(0).getNCol());
			uniqueCosts.add(RewriterUtils.parse("*(argList(nrowA, ncolA))", ctx, dims));
			return;
		}
		default:
			break;
	}

	// Every remaining scalar instruction is costed as a constant literal.
	final long atomicCost = atomicOpCost(instr.trueInstruction());
	final String literalText = Long.toString(atomicCost);
	uniqueCosts.add(RewriterUtils.parse(literalText, ctx, "LITERAL_INT:" + atomicCost));
}

Expand All @@ -122,10 +186,6 @@ private static RewriterStatement atomicOpCostStmt(String op, final RuleContext c
return RewriterUtils.parse(Long.toString(opCost), ctx, "LITERAL_INT:" + opCost);
}

private static RewriterStatement literalInt(long value) {
return new RewriterDataType().as(Long.toString(value)).ofType("INT").asLiteral(value);
}

private static long atomicOpCost(String op) {
switch (op) {
case "+":
Expand All @@ -136,6 +196,27 @@ private static long atomicOpCost(String op) {
case "/":
case "inv":
return 3;
case "length":
case "nrow":
case "ncol":
return 0; // These just fetch metadata
case "sqrt":
return 10;
case "exp":
case "^":
return 20;
case "!":
case "|":
case "&":
case ">":
case ">=":
case "<":
case "<=":
return 1;
case "round":
return 2;
case "abs":
return 2;
}

throw new IllegalArgumentException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,19 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
}, true)
.build()
);

// cast.MATRIX(a) => _m(1, 1, a)
for (String t : List.of("INT", "BOOL", "FLOAT")) {
rules.add(new RewriterRuleBuilder(ctx)
.setUnidirectional(true)
.parseGlobalVars(t + ":a")
.parseGlobalVars("LITERAL_INT:1")
.withParsedStatement("cast.MATRIX(a)", hooks)
.toParsedStatement("$2:_m(1, 1, a)", hooks)
.apply(hooks.get(2).getId(), (stmt, match) -> stmt.unsafePutMeta("ownerId", UUID.randomUUID()), true)
.build()
);
}
}

public static void expandArbitraryMatrices(final List<RewriterRule> rules, final RuleContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,4 +722,20 @@ public void testConstantFolding4() {

assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}

/**
 * Verifies that {@code sum(+(A, -7))} and {@code sum(-(A, 7))} are an exact match
 * after both have been passed through the canonical form converter.
 */
@Test
public void testAdvancedEquivalence1() {
	final String separator = "==========";

	RewriterStatement addNegative = RewriterUtils.parse("sum(+(A, -7))", ctx, "MATRIX:A", "LITERAL_FLOAT:-7");
	RewriterStatement subPositive = RewriterUtils.parse("sum(-(A, 7))", ctx, "MATRIX:A", "LITERAL_FLOAT:7");

	RewriterStatement canonAdd = canonicalConverter.apply(addNegative);
	RewriterStatement canonSub = canonicalConverter.apply(subPositive);

	// Dump both canonical forms to ease debugging when the match fails.
	System.out.println(separator);
	System.out.println(canonAdd.toParsableString(ctx, true));
	System.out.println(separator);
	System.out.println(canonSub.toParsableString(ctx, true));

	assert canonAdd.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonSub));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import org.junit.BeforeClass;
import org.junit.Test;

import java.util.Random;
import java.util.function.Function;

public class CostEstimates {
Expand All @@ -17,7 +16,7 @@ public class CostEstimates {
@BeforeClass
public static void setup() {
ctx = RewriterUtils.buildDefaultContext();
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true);
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
}

@Test
Expand Down Expand Up @@ -47,6 +46,82 @@ public void test4() {
long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx);
System.out.println("Cost1: " + cost1);
System.out.println("Cost2: " + cost2);
System.out.println("Ratio: " + ((double)cost1)/cost2);
assert cost1 < cost2;
}

/**
 * Expects {@code t(/(*(A, B), C))} (a single {@code t} around the whole expression) to be
 * estimated cheaper than {@code /(*(t(A), t(B)), t(C))} (a {@code t} on each operand), and
 * both expressions to share the same canonical form.
 */
@Test
public void test5() {
	final Function<RewriterStatement, Long> fixedSize = el -> 2000L;

	RewriterStatement outerT = RewriterUtils.parse("t(/(*(A, B), C))", ctx, "MATRIX:A,B,C");
	RewriterStatement innerT = RewriterUtils.parse("/(*(t(A), t(B)), t(C))", ctx, "MATRIX:A,B,C");

	long cost1 = RewriterCostEstimator.estimateCost(outerT, fixedSize, ctx);
	long cost2 = RewriterCostEstimator.estimateCost(innerT, fixedSize, ctx);
	double ratio = (double) cost1 / cost2;

	System.out.println("Cost1: " + cost1);
	System.out.println("Cost2: " + cost2);
	System.out.println("Ratio: " + ratio);
	assert cost1 < cost2;

	RewriterStatement canonOuter = canonicalConverter.apply(outerT);
	RewriterStatement canonInner = canonicalConverter.apply(innerT);
	assert canonOuter.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonInner));
}

/**
 * Expects {@code +(sum(A), sum(B))} to be estimated cheaper than {@code sum(+(A, B))},
 * and both expressions to share the same canonical form.
 */
@Test
public void test6() {
	final Function<RewriterStatement, Long> fixedSize = el -> 2000L;

	RewriterStatement sumOfAdd = RewriterUtils.parse("sum(+(A, B))", ctx, "MATRIX:A,B,C");
	RewriterStatement addOfSums = RewriterUtils.parse("+(sum(A), sum(B))", ctx, "MATRIX:A,B,C");

	long cost1 = RewriterCostEstimator.estimateCost(sumOfAdd, fixedSize, ctx);
	long cost2 = RewriterCostEstimator.estimateCost(addOfSums, fixedSize, ctx);
	// Note the direction: this test reports and checks cost2 relative to cost1.
	double ratio = (double) cost2 / cost1;

	System.out.println("Cost1: " + cost1);
	System.out.println("Cost2: " + cost2);
	System.out.println("Ratio: " + ratio);
	assert cost2 < cost1;

	RewriterStatement canonSumOfAdd = canonicalConverter.apply(sumOfAdd);
	RewriterStatement canonAddOfSums = canonicalConverter.apply(addOfSums);
	assert canonSumOfAdd.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonAddOfSums));
}

/**
 * Expects {@code cast.MATRIX(sum(A))} to be estimated cheaper than
 * {@code rowSums(colSums(A))}, and both expressions to share the same canonical form.
 */
@Test
public void test7() {
	final Function<RewriterStatement, Long> fixedSize = el -> 2000L;
	final String separator = "==========";

	RewriterStatement castOfSum = RewriterUtils.parse("cast.MATRIX(sum(A))", ctx, "MATRIX:A,B,C");
	RewriterStatement nestedSums = RewriterUtils.parse("rowSums(colSums(A))", ctx, "MATRIX:A,B,C");

	long cost1 = RewriterCostEstimator.estimateCost(castOfSum, fixedSize, ctx);
	long cost2 = RewriterCostEstimator.estimateCost(nestedSums, fixedSize, ctx);
	double ratio = (double) cost1 / cost2;

	System.out.println("Cost1: " + cost1);
	System.out.println("Cost2: " + cost2);
	System.out.println("Ratio: " + ratio);
	assert cost1 < cost2;

	RewriterStatement canonCast = canonicalConverter.apply(castOfSum);
	RewriterStatement canonNested = canonicalConverter.apply(nestedSums);

	// Dump both canonical forms to ease debugging when the match fails.
	System.out.println(separator);
	System.out.println(canonCast.toParsableString(ctx, true));
	System.out.println(separator);
	System.out.println(canonNested.toParsableString(ctx, true));
	assert canonCast.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonNested));
}

/**
 * Expects {@code sum(*(diag(A), diag(B)))} to be estimated cheaper than
 * {@code trace(*(A, B))}, and both expressions to share the same canonical form.
 */
@Test
public void test8() {
	final Function<RewriterStatement, Long> fixedSize = el -> 2000L;
	final String separator = "==========";

	RewriterStatement sumOfDiags = RewriterUtils.parse("sum(*(diag(A), diag(B)))", ctx, "MATRIX:A,B,C");
	RewriterStatement traceOfProduct = RewriterUtils.parse("trace(*(A, B))", ctx, "MATRIX:A,B,C");

	long cost1 = RewriterCostEstimator.estimateCost(sumOfDiags, fixedSize, ctx);
	long cost2 = RewriterCostEstimator.estimateCost(traceOfProduct, fixedSize, ctx);
	double ratio = (double) cost1 / cost2;

	System.out.println("Cost1: " + cost1);
	System.out.println("Cost2: " + cost2);
	System.out.println("Ratio: " + ratio);
	assert cost1 < cost2;

	RewriterStatement canonDiags = canonicalConverter.apply(sumOfDiags);
	RewriterStatement canonTrace = canonicalConverter.apply(traceOfProduct);

	// Dump both canonical forms to ease debugging when the match fails.
	System.out.println(separator);
	System.out.println(canonDiags.toParsableString(ctx, true));
	System.out.println(separator);
	System.out.println(canonTrace.toParsableString(ctx, true));
	assert canonDiags.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonTrace));
}
}

0 comments on commit e63f5e9

Please sign in to comment.