From 4990f9bd2910f7ad61a4cd3d7438002709432b65 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Thu, 7 Nov 2024 12:43:33 +0100 Subject: [PATCH] Rule generator --- .../sysds/hops/rewriter/RewriterRule.java | 4 +- .../hops/rewriter/RewriterRuleBuilder.java | 11 +++ .../hops/rewriter/RewriterRuleCreator.java | 89 +++++++++++++++++++ .../hops/rewriter/RewriterStatement.java | 1 + .../rewrite/functions/RuleCreationTests.java | 39 ++++++++ 5 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java index f8182311067..bc61370ee4e 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java @@ -331,9 +331,9 @@ private RewriterStatement applyInplace(RewriterStatement.MatchingSubexpression m public String toString() { if (isUnidirectional()) - return fromRoot.toString() + " => " + toRoot.toString(); + return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx); else - return fromRoot.toString() + " <=> " + toRoot.toString(); + return fromRoot.toParsableString(ctx) + " <=> " + toRoot.toParsableString(ctx); } // TODO: Rework diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleBuilder.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleBuilder.java index 2da495af5ba..0a97b83b52b 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleBuilder.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleBuilder.java @@ -235,6 +235,17 @@ public RewriterRuleBuilder withInstruction(String instr) { return this; } + /** Assigns fromRoot and toRoot directly from two already-built statements, sets mappingState to true and returns this builder. Throws IllegalArgumentException when canBeModified is false or mappingState is already true. NOTE(review): the second message speaks of adding an instruction, which does not describe this method - it looks copied from a sibling builder method; confirm the wording. */ public RewriterRuleBuilder completeRule(RewriterStatement from, RewriterStatement to) { + if 
(!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + this.fromRoot = from; + this.toRoot = to; + this.mappingState = true; + return this; + } + public RewriterRuleBuilder withOps(RewriterDataType... operands) { if (!canBeModified) throw new IllegalArgumentException("The DAG is final and cannot be modified"); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java new file mode 100644 index 00000000000..fd633c51ed8 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java @@ -0,0 +1,89 @@ +package org.apache.sysds.hops.rewriter; + +import java.util.HashMap; +import java.util.Map; + +/** Creates a RewriterRule from a pair of statements (from, to) together with the canonical form of each. NOTE(review): the bare Map declarations in this class are read into typed locals without casts, so their type arguments were presumably lost when this patch text was extracted - verify against the repository before applying. */ public class RewriterRuleCreator { + /** Copies from and to with nestedCopy(), then walks the copied from statement in pre-order and replaces every non-literal RewriterDataType operand with the statement that getAssociations() associates with it. The two statements are then handed to RewriterRuleBuilder.completeRule() to build a unidirectional rule. Throws IllegalArgumentException when an operand has no association. NOTE(review): only operands of visited nodes are replaced, so a from statement that is itself a single variable would keep its root as is - confirm whether that input can occur. */ public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) { + from = from.nestedCopy(); + to = to.nestedCopy(); + Map assocs = getAssociations(from, to, canonicalForm1, canonicalForm2, ctx); + + // Now, we replace all variables with a common element + from.forEachPreOrder((cur, parent, pIdx) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child instanceof RewriterDataType && !child.isLiteral()) { + RewriterStatement newRef = assocs.get(cur.getChild(i)); + + if (newRef == null) + throw new IllegalArgumentException(); + + cur.getOperands().set(i, newRef); + } + } + + return true; + }); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build(); + return rule; + } + + /** Builds the association map used by createRule(). Both canonical forms are first linked back to their source statements by variable id; canonicalFormFrom is then matched against an exact-match context created on canonicalFormTo, and every non-literal entry of the matcher's dependency map is translated through those links. Throws IllegalArgumentException when either side of an entry has no link. */ private static Map getAssociations(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalFormFrom, 
RewriterStatement canonicalFormTo, final RuleContext ctx) { + Map fromCanonicalLink = getAssociationToCanonicalForm(from, canonicalFormFrom, true); + Map toCanonicalLink = getAssociationToCanonicalForm(to, canonicalFormTo, true); + + RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo); + canonicalFormFrom.match(matcher); /* NOTE(review): the result of match() is not checked here; the test in this patch asserts it separately before calling createRule - confirm that callers guarantee a successful match. */ + + Map assocs = new HashMap<>(); + matcher.getDependencyMap().forEach((k, v) -> { + if (k.isLiteral()) + return; + + RewriterStatement newKey = fromCanonicalLink.get(k); + RewriterStatement newValue = toCanonicalLink.get(v); + + if (newKey == null || newValue == null) + throw new IllegalArgumentException("Null reference detected!"); + + assocs.put(newKey, newValue); + }); + + return assocs; + } + + /** Pairs the non-literal RewriterDataType nodes of canonicalForm with the nodes of stmt that carry the same getId(). With reversed set, the returned map goes from canonical node to stmt node; otherwise from stmt node to canonical node. Throws IllegalArgumentException when stmt yields the same id twice or when canonicalForm holds an id that stmt does not. Both call sites in this patch pass reversed = true. */ private static Map getAssociationToCanonicalForm(RewriterStatement stmt, RewriterStatement canonicalForm, boolean reversed) { + // We identify all associations by their names + // If there are name collisions, this does not work + Map namedVariables = new HashMap<>(); + stmt.forEachPostOrder((cur, parent, pIdx) -> { + if (!(cur instanceof RewriterDataType) || cur.isLiteral()) + return; + + if (namedVariables.put(cur.getId(), cur) != null) + throw new IllegalArgumentException("Duplicate variable name: " + cur.toParsableString(RuleContext.currentContext)); + }); + + Map assoc = new HashMap<>(); + + canonicalForm.forEachPostOrder((cur, parent, pIdx) -> { + if (!(cur instanceof RewriterDataType) || cur.isLiteral()) + return; + + RewriterStatement ref = namedVariables.get(cur.getId()); + + if (ref == null) + throw new IllegalArgumentException("Unknown variable reference name '" + cur.getId() + "' in: " + cur.toParsableString(RuleContext.currentContext)); + + if (reversed) + assoc.put(cur, ref); + else + assoc.put(ref, cur); + }); + + return assoc; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java 
index 3748110a485..ed2effbf718 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -386,6 +386,7 @@ public RewriterStatement nestedCopyOrInject(Map injector.apply(el), null, -1); } + // TODO: This does not copy the associations if they exist public RewriterStatement nestedCopy() { return nestedCopyOrInject(new HashMap<>(), el -> null); } diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java new file mode 100644 index 00000000000..919906f3afd --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java @@ -0,0 +1,39 @@ +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.sysds.hops.rewriter.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +/** Exercises RewriterRuleCreator.createRule() on one parsed expression pair. */ public class RuleCreationTests { + private static RuleContext ctx; + private static Function canonicalConverter; + + /* Builds the shared rule context and canonical-form converter once for the whole class. */ @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + /* Parses the from and to expressions, prints both canonical forms, asserts that they match, then builds the rule and prints it. NOTE(review): nothing is asserted about the generated rule itself. NOTE(review): if t is transpose and %*% is matrix multiplication, t(t(U) %*% V) equals t(V) %*% U rather than t(U) %*% V - confirm this pair is the intended fixture. */ @Test + public void test1() { + RewriterStatement from = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("%*%(t(U), V)", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + System.out.println("=========="); + System.out.println(canonicalForm1.toParsableString(ctx, 
true)); + System.out.println("=========="); + System.out.println(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2)); /* NOTE(review): a plain Java assert is only evaluated when the JVM runs with assertions enabled - confirm the test runner passes -ea, or this check is skipped. */ + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + System.out.println(rule); + } +}