forked from apache/systemds
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
142 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package org.apache.sysds.hops.rewriter; | ||
|
||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
public class RewriterRuleCreator { | ||
public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) { | ||
from = from.nestedCopy(); | ||
to = to.nestedCopy(); | ||
Map<RewriterStatement, RewriterStatement> assocs = getAssociations(from, to, canonicalForm1, canonicalForm2, ctx); | ||
|
||
// Now, we replace all variables with a common element | ||
from.forEachPreOrder((cur, parent, pIdx) -> { | ||
for (int i = 0; i < cur.getOperands().size(); i++) { | ||
RewriterStatement child = cur.getChild(i); | ||
|
||
if (child instanceof RewriterDataType && !child.isLiteral()) { | ||
RewriterStatement newRef = assocs.get(cur.getChild(i)); | ||
|
||
if (newRef == null) | ||
throw new IllegalArgumentException(); | ||
|
||
cur.getOperands().set(i, newRef); | ||
} | ||
} | ||
|
||
return true; | ||
}); | ||
|
||
RewriterRule rule = new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build(); | ||
return rule; | ||
} | ||
|
||
private static Map<RewriterStatement, RewriterStatement> getAssociations(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalFormFrom, RewriterStatement canonicalFormTo, final RuleContext ctx) { | ||
Map<RewriterStatement, RewriterStatement> fromCanonicalLink = getAssociationToCanonicalForm(from, canonicalFormFrom, true); | ||
Map<RewriterStatement, RewriterStatement> toCanonicalLink = getAssociationToCanonicalForm(to, canonicalFormTo, true); | ||
|
||
RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo); | ||
canonicalFormFrom.match(matcher); | ||
|
||
Map<RewriterStatement, RewriterStatement> assocs = new HashMap<>(); | ||
matcher.getDependencyMap().forEach((k, v) -> { | ||
if (k.isLiteral()) | ||
return; | ||
|
||
RewriterStatement newKey = fromCanonicalLink.get(k); | ||
RewriterStatement newValue = toCanonicalLink.get(v); | ||
|
||
if (newKey == null || newValue == null) | ||
throw new IllegalArgumentException("Null reference detected!"); | ||
|
||
assocs.put(newKey, newValue); | ||
}); | ||
|
||
return assocs; | ||
} | ||
|
||
private static Map<RewriterStatement, RewriterStatement> getAssociationToCanonicalForm(RewriterStatement stmt, RewriterStatement canonicalForm, boolean reversed) { | ||
// We identify all associations by their names | ||
// If there are name collisions, this does not work | ||
Map<String, RewriterStatement> namedVariables = new HashMap<>(); | ||
stmt.forEachPostOrder((cur, parent, pIdx) -> { | ||
if (!(cur instanceof RewriterDataType) || cur.isLiteral()) | ||
return; | ||
|
||
if (namedVariables.put(cur.getId(), cur) != null) | ||
throw new IllegalArgumentException("Duplicate variable name: " + cur.toParsableString(RuleContext.currentContext)); | ||
}); | ||
|
||
Map<RewriterStatement, RewriterStatement> assoc = new HashMap<>(); | ||
|
||
canonicalForm.forEachPostOrder((cur, parent, pIdx) -> { | ||
if (!(cur instanceof RewriterDataType) || cur.isLiteral()) | ||
return; | ||
|
||
RewriterStatement ref = namedVariables.get(cur.getId()); | ||
|
||
if (ref == null) | ||
throw new IllegalArgumentException("Unknown variable reference name '" + cur.getId() + "' in: " + cur.toParsableString(RuleContext.currentContext)); | ||
|
||
if (reversed) | ||
assoc.put(cur, ref); | ||
else | ||
assoc.put(ref, cur); | ||
}); | ||
|
||
return assoc; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
...est/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package org.apache.sysds.test.component.codegen.rewrite.functions; | ||
|
||
import org.apache.sysds.hops.rewriter.RewriterRule; | ||
import org.apache.sysds.hops.rewriter.RewriterRuleCreator; | ||
import org.apache.sysds.hops.rewriter.RewriterStatement; | ||
import org.apache.sysds.hops.rewriter.RewriterUtils; | ||
import org.apache.sysds.hops.rewriter.RuleContext; | ||
import org.junit.BeforeClass; | ||
import org.junit.Test; | ||
|
||
import java.util.function.Function; | ||
|
||
public class RuleCreationTests { | ||
private static RuleContext ctx; | ||
private static Function<RewriterStatement, RewriterStatement> canonicalConverter; | ||
|
||
@BeforeClass | ||
public static void setup() { | ||
ctx = RewriterUtils.buildDefaultContext(); | ||
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); | ||
} | ||
|
||
@Test | ||
public void test1() { | ||
RewriterStatement from = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V"); | ||
RewriterStatement to = RewriterUtils.parse("%*%(t(U), V)", ctx, "MATRIX:U,V"); | ||
RewriterStatement canonicalForm1 = canonicalConverter.apply(from); | ||
RewriterStatement canonicalForm2 = canonicalConverter.apply(to); | ||
|
||
System.out.println("=========="); | ||
System.out.println(canonicalForm1.toParsableString(ctx, true)); | ||
System.out.println("=========="); | ||
System.out.println(canonicalForm2.toParsableString(ctx, true)); | ||
assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2)); | ||
|
||
RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); | ||
System.out.println(rule); | ||
} | ||
} |