Skip to content

Commit

Permalink
Rule generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 0286e48 commit 4990f9b
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ private RewriterStatement applyInplace(RewriterStatement.MatchingSubexpression m

public String toString() {
if (isUnidirectional())
return fromRoot.toString() + " => " + toRoot.toString();
return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx);
else
return fromRoot.toString() + " <=> " + toRoot.toString();
return fromRoot.toParsableString(ctx) + " <=> " + toRoot.toParsableString(ctx);
}

// TODO: Rework
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,17 @@ public RewriterRuleBuilder withInstruction(String instr) {
return this;
}

public RewriterRuleBuilder completeRule(RewriterStatement from, RewriterStatement to) {
if (!canBeModified)
throw new IllegalArgumentException("The DAG is final and cannot be modified");
if (mappingState)
throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined");
this.fromRoot = from;
this.toRoot = to;
this.mappingState = true;
return this;
}

public RewriterRuleBuilder withOps(RewriterDataType... operands) {
if (!canBeModified)
throw new IllegalArgumentException("The DAG is final and cannot be modified");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package org.apache.sysds.hops.rewriter;

import java.util.HashMap;
import java.util.Map;

public class RewriterRuleCreator {
public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) {
from = from.nestedCopy();
to = to.nestedCopy();
Map<RewriterStatement, RewriterStatement> assocs = getAssociations(from, to, canonicalForm1, canonicalForm2, ctx);

// Now, we replace all variables with a common element
from.forEachPreOrder((cur, parent, pIdx) -> {
for (int i = 0; i < cur.getOperands().size(); i++) {
RewriterStatement child = cur.getChild(i);

if (child instanceof RewriterDataType && !child.isLiteral()) {
RewriterStatement newRef = assocs.get(cur.getChild(i));

if (newRef == null)
throw new IllegalArgumentException();

cur.getOperands().set(i, newRef);
}
}

return true;
});

RewriterRule rule = new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build();
return rule;
}

private static Map<RewriterStatement, RewriterStatement> getAssociations(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalFormFrom, RewriterStatement canonicalFormTo, final RuleContext ctx) {
Map<RewriterStatement, RewriterStatement> fromCanonicalLink = getAssociationToCanonicalForm(from, canonicalFormFrom, true);
Map<RewriterStatement, RewriterStatement> toCanonicalLink = getAssociationToCanonicalForm(to, canonicalFormTo, true);

RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo);
canonicalFormFrom.match(matcher);

Map<RewriterStatement, RewriterStatement> assocs = new HashMap<>();
matcher.getDependencyMap().forEach((k, v) -> {
if (k.isLiteral())
return;

RewriterStatement newKey = fromCanonicalLink.get(k);
RewriterStatement newValue = toCanonicalLink.get(v);

if (newKey == null || newValue == null)
throw new IllegalArgumentException("Null reference detected!");

assocs.put(newKey, newValue);
});

return assocs;
}

private static Map<RewriterStatement, RewriterStatement> getAssociationToCanonicalForm(RewriterStatement stmt, RewriterStatement canonicalForm, boolean reversed) {
// We identify all associations by their names
// If there are name collisions, this does not work
Map<String, RewriterStatement> namedVariables = new HashMap<>();
stmt.forEachPostOrder((cur, parent, pIdx) -> {
if (!(cur instanceof RewriterDataType) || cur.isLiteral())
return;

if (namedVariables.put(cur.getId(), cur) != null)
throw new IllegalArgumentException("Duplicate variable name: " + cur.toParsableString(RuleContext.currentContext));
});

Map<RewriterStatement, RewriterStatement> assoc = new HashMap<>();

canonicalForm.forEachPostOrder((cur, parent, pIdx) -> {
if (!(cur instanceof RewriterDataType) || cur.isLiteral())
return;

RewriterStatement ref = namedVariables.get(cur.getId());

if (ref == null)
throw new IllegalArgumentException("Unknown variable reference name '" + cur.getId() + "' in: " + cur.toParsableString(RuleContext.currentContext));

if (reversed)
assoc.put(cur, ref);
else
assoc.put(ref, cur);
});

return assoc;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ public RewriterStatement nestedCopyOrInject(Map<RewriterStatement, RewriterState
return nestedCopyOrInject(copiedObjects, (el, parent, pIdx) -> injector.apply(el), null, -1);
}

// TODO: This does not copy the associations if they exist
public RewriterStatement nestedCopy() {
return nestedCopyOrInject(new HashMap<>(), el -> null);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.apache.sysds.test.component.codegen.rewrite.functions;

import org.apache.sysds.hops.rewriter.RewriterRule;
import org.apache.sysds.hops.rewriter.RewriterRuleCreator;
import org.apache.sysds.hops.rewriter.RewriterStatement;
import org.apache.sysds.hops.rewriter.RewriterUtils;
import org.apache.sysds.hops.rewriter.RuleContext;
import org.junit.BeforeClass;
import org.junit.Test;

import java.util.function.Function;

public class RuleCreationTests {
private static RuleContext ctx;
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;

@BeforeClass
public static void setup() {
ctx = RewriterUtils.buildDefaultContext();
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
}

@Test
public void test1() {
RewriterStatement from = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V");
RewriterStatement to = RewriterUtils.parse("%*%(t(U), V)", ctx, "MATRIX:U,V");
RewriterStatement canonicalForm1 = canonicalConverter.apply(from);
RewriterStatement canonicalForm2 = canonicalConverter.apply(to);

System.out.println("==========");
System.out.println(canonicalForm1.toParsableString(ctx, true));
System.out.println("==========");
System.out.println(canonicalForm2.toParsableString(ctx, true));
assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2));

RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx);
System.out.println(rule);
}
}

0 comments on commit 4990f9b

Please # to comment.