Skip to content

Commit

Permalink
Some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 0d5651f commit bfcbe44
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ public class RewriterAlphabetEncoder {
new Operand("colSums", 1, MATRIX),
new Operand("max", 1, MATRIX),
new Operand("min", 1, MATRIX),
new Operand("fncol", 1, true, MATRIX),
new Operand("ncol", 1, true, MATRIX),
new Operand("nrow", 1, true, MATRIX),
new Operand("length", 1, true, MATRIX),
/*new Operand("fncol", 1, true, MATRIX),
new Operand("fnrow", 1, true, MATRIX),
new Operand("flength", 1, true, MATRIX),
new Operand("flength", 1, true, MATRIX),*/

new Operand("!=", 2, ALL_TYPES, ALL_TYPES),
new Operand("!=0", 1, MATRIX),
Expand All @@ -76,9 +79,9 @@ public class RewriterAlphabetEncoder {
new Operand("c_-1", 1, ALL_TYPES),

// ncol / nrow / length stuff
new Operand("c_flength*", 1, ALL_TYPES),
new Operand("c_fncol*", 1, ALL_TYPES),
new Operand("c_fnrow*", 1, ALL_TYPES),
new Operand("c_length*", 2, MATRIX, ALL_TYPES),
new Operand("c_ncol*", 2, MATRIX, ALL_TYPES),
new Operand("c_nrow*", 2, MATRIX, ALL_TYPES),

//new Operand("log_nz", 1, MATRIX), // TODO: We have to include literals in the search

Expand Down Expand Up @@ -365,6 +368,13 @@ private static RewriterStatement buildStmt(Operand op, RewriterStatement[] stack
stmt.withInstruction("!=").addOp(RewriterStatement.literal(ctx, 0.0D)).addOp(stack[0]);
break;
}
case "ncol":
case "nrow":
case "length": {
String actualOp = op.op.substring(1);
stmt.withInstruction(actualOp).withOps(stack).consolidate(ctx);
break;
}
case "fncol":
case "fnrow":
case "flength": {
Expand All @@ -378,8 +388,37 @@ private static RewriterStatement buildStmt(Operand op, RewriterStatement[] stack
stmt = new RewriterInstruction("*", ctx, old, stack[1]);
break;
}
case "c_1+": {
stmt = new RewriterInstruction("+", ctx, RewriterStatement.literal(ctx, 1.0D), stack[0]);
break;
}
case "c_+1": {
stmt = new RewriterInstruction("+", ctx, stack[0], RewriterStatement.literal(ctx, 1.0D));
break;
}
case "c_1-": {
stmt = new RewriterInstruction("-", ctx, RewriterStatement.literal(ctx, 1.0D), stack[0]);
break;
}
case "c_-1": {
stmt = new RewriterInstruction("-", ctx, stack[0], RewriterStatement.literal(ctx, 1.0D));
break;
}
case "c_length*": {
stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("length", ctx, stack[0]), stack[1]);
break;
}
case "c_nrow*": {
stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("nrow", ctx, stack[0]), stack[1]);
break;
}
case "c_col*": {
stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("ncol", ctx, stack[0]), stack[1]);
break;
}
default: {
stmt.withInstruction(op.op).withOps(stack);
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,6 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
}

// This expands the statements to a common canonical form
// It is important, however, that
public static void canonicalExpandAfterFlattening(final List<RewriterRule> rules, final RuleContext ctx) {
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.sysds.hops.rewriter.dml.DMLExecutor;
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
import org.apache.sysds.hops.rewriter.utils.RewriterUtils;
import scala.Tuple2;

import javax.annotation.Nullable;
import java.util.ArrayList;
Expand Down Expand Up @@ -377,8 +378,20 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
}

public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) {
Tuple2<RewriterStatement, RewriterStatement> commonForm = createCommonForm(from, to, canonicalForm1, canonicalForm2, ctx);
from = commonForm._1;
to = commonForm._2;

return new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build();
}

public static RewriterRule createRuleFromCommonStatements(RewriterStatement from, RewriterStatement to, final RuleContext ctx) {
return new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build();
}

public static Tuple2<RewriterStatement, RewriterStatement> createCommonForm(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) {
from = from.nestedCopy(true);
to = to.nestedCopy(true);
//to = to.nestedCopy(true);
Map<RewriterStatement, RewriterStatement> assocs = getAssociations(from, to, canonicalForm1, canonicalForm2, ctx);

// Now, we replace all variables with a common element
Expand All @@ -401,10 +414,8 @@ public static RewriterRule createRule(RewriterStatement from, RewriterStatement
}, false);

from = ctx.metaPropagator.apply(from);
to = ctx.metaPropagator.apply(to);

RewriterRule rule = new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build();
return rule;
//to = ctx.metaPropagator.apply(to);
return new Tuple2<>(from, to);
}

private static Map<RewriterStatement, RewriterStatement> getAssociations(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalFormFrom, RewriterStatement canonicalFormTo, final RuleContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,7 @@ public static Function<RewriterStatement, RewriterStatement> unfuseOperators(fin
private static RuleContext lastSparsityCtx;
private static Function<RewriterStatement, RewriterStatement> lastPrepareForSparsity;

@Deprecated
public static Function<RewriterStatement, RewriterStatement> prepareForSparsityEstimation(final RuleContext ctx) {
if (lastSparsityCtx == ctx)
return lastPrepareForSparsity;
Expand Down Expand Up @@ -1455,8 +1456,8 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule

if (!checkSubgraphDependency(sumBody, ownerId, checked)) {
// Then we have to remove the sum entirely
RewriterStatement negation = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(RewriterStatement.ensureFloat(ctx, idxFrom)).consolidate(ctx);
RewriterStatement add = RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.ensureFloat(ctx, idxTo), negation);
RewriterStatement negation = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idxFrom)*/idxFrom).consolidate(ctx);
RewriterStatement add = RewriterStatement.multiArgInstr(ctx, "+", /*RewriterStatement.ensureFloat(ctx, idxTo)*/idxTo, negation);
add = foldConstants(add, ctx);
return RewriterStatement.multiArgInstr(ctx, "*", sumBody, add);
}
Expand Down Expand Up @@ -1510,8 +1511,8 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule
List<RewriterStatement> mul = new ArrayList<>();

for (RewriterStatement idx : idxExpr.getChild(0).getOperands()) {
RewriterStatement neg = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(RewriterStatement.ensureFloat(ctx, idx.getChild(0))).consolidate(ctx);
RewriterStatement msum = RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.ensureFloat(ctx, idx.getChild(1)), neg, RewriterStatement.literal(ctx, 1.0));
RewriterStatement neg = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idx.getChild(0))*/idx.getChild(0)).consolidate(ctx);
RewriterStatement msum = RewriterStatement.multiArgInstr(ctx, "+", /*RewriterStatement.ensureFloat(ctx, idx.getChild(1))*/idx.getChild(1), neg, RewriterStatement.literal(ctx, 1.0));
mul.add(msum);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public static void setup() {

public static void testExpressionClustering() {
boolean useData = false;
boolean useRandomized = true;
boolean useSystematic = true;
int systematicSearchDepth = 2;
boolean useRandomLarge = false;

long startTime = System.currentTimeMillis();
Expand Down Expand Up @@ -159,10 +160,10 @@ public static void testExpressionClustering() {
db = null;
Object lock = new Object();

if (useRandomized) {
if (useSystematic) {
long MAX_MILLIS = 1200000; // Should be bound by number of ops
int BATCH_SIZE = 400;
int maxN = RewriterAlphabetEncoder.getMaxSearchNumberForNumOps(2);
int maxN = RewriterAlphabetEncoder.getMaxSearchNumberForNumOps(systematicSearchDepth);
System.out.println("MaxN: " + maxN);
long startMillis = System.currentTimeMillis();

Expand Down Expand Up @@ -198,6 +199,12 @@ public static void testExpressionClustering() {
synchronized (lock) {
RewriterEquivalenceDatabase.DBEntry entry = canonicalExprDB.insert(ctx, canonicalForm, stmt);

// Now, we use common variables
if (entry.equivalences.size() > 1) {
RewriterStatement commonForm = RewriterRuleCreator.createCommonForm(entry.equivalences.get(0), stmt, entry.canonicalForm, canonicalForm, ctx)._2;
entry.equivalences.set(entry.equivalences.size()-1, commonForm);
}

if (entry.equivalences.size() == 2)
foundEquivalences.add(entry);
}
Expand Down Expand Up @@ -248,6 +255,12 @@ public static void testExpressionClustering() {
synchronized (lock) {
RewriterEquivalenceDatabase.DBEntry entry = canonicalExprDB.insert(ctx, canonicalForm, stmt);

// Now, we use common variables
if (entry.equivalences.size() > 1) {
RewriterStatement commonForm = RewriterRuleCreator.createCommonForm(entry.equivalences.get(0), stmt, entry.canonicalForm, canonicalForm, ctx)._2;
entry.equivalences.set(entry.equivalences.size()-1, commonForm);
}

if (entry.equivalences.size() == 2)
foundEquivalences.add(entry);
}
Expand Down Expand Up @@ -279,15 +292,10 @@ public static void testExpressionClustering() {
if (++mCtr % 100 == 0)
System.out.println("Creating rule: " + mCtr + " / " + rewrites.size());

ctx.metaPropagator.apply(rewrite._4());
ctx.metaPropagator.apply(rewrite._5());
RewriterStatement canonicalFormFrom = converter.apply(rewrite._4());
RewriterStatement canonicalFormTo = converter.apply(rewrite._5());
try {
RewriterRule rule = RewriterRuleCreator.createRule(rewrite._4(), rewrite._5(), canonicalFormFrom, canonicalFormTo, ctx);
RewriterRule rule = RewriterRuleCreator.createRuleFromCommonStatements(rewrite._4(), rewrite._5(), ctx);

allRules.add(new Tuple4<>(rule, rewrite._2(), rewrite._3(), rule.getStmt1().countInstructions()));
//ruleCreator.registerRule(rule, rewrite._2(), rewrite._3());
} catch (Exception e) {
System.err.println("An error occurred while trying to create a rule:");
System.err.println(rewrite._4().toParsableString(ctx, true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1602,4 +1602,36 @@ public void testWrong5() {

assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
}

@Test
public void testConstInequivality() {
RewriterStatement stmt1 = RewriterUtils.parse("%*%(const(A, 0.0), A)", ctx, "MATRIX:A", "LITERAL_FLOAT:0.0");
RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0.0");

stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));

assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
}

@Test
public void testSumEquality7() {
RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a, A))", ctx, "MATRIX:A", "FLOAT:a");
RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(A))", ctx, "MATRIX:A", "FLOAT:a");

stmt1 = canonicalConverter.apply(stmt1);
stmt2 = canonicalConverter.apply(stmt2);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(stmt2.toParsableString(ctx, true));

assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
}
}

0 comments on commit bfcbe44

Please # to comment.