From 549d4566eca5fc8714430504060ad07a5c10cc5b Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 7 Jan 2025 16:44:49 +0100 Subject: [PATCH] Multirule parsing --- .../sysds/hops/rewriter/RewriterRule.java | 5 +- .../hops/rewriter/utils/RewriterUtils.java | 54 +++++++++++++++++++ .../functions/RuleSerializationTest.java | 16 ++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java index ce122975338..bd94d01d18d 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java @@ -458,7 +458,10 @@ private RewriterStatement applyInplace(RewriterStatement.MatchingSubexpression m public String toString() { if (isUnidirectional()) - return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx); + if (isConditionalMultiRule()) + return fromRoot.toParsableString(ctx) + " => {" + toRoots.stream().map(stmt -> stmt.toParsableString(ctx)).collect(Collectors.joining("; ")) + "}"; + else + return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx); else return fromRoot.toParsableString(ctx) + " <=> " + toRoot.toParsableString(ctx); } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java index ec28e1fc07d..49dc722f059 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java @@ -363,6 +363,7 @@ public static RewriterRule parseRule(String expr, final RuleContext ctx) { Set allowedMultiRefs = Collections.emptySet(); boolean allowCombinations = false; boolean parsedExtendedHeader = false; + if (split[0].startsWith("AllowedMultiRefs:")) { split[0] = split[0].substring(17); String[] sSplit = split[0].split(","); @@ -375,6 +376,22 @@ public static RewriterRule parseRule(String expr, final RuleContext ctx) { allowCombinations = Boolean.parseBoolean(split[1]); parsedExtendedHeader = true; } + + int condIdxStart = -1; + for (int i = 2; i < split.length; i++) { + if (split[i].startsWith("{")) { + // Then we have a conditional rule + condIdxStart = i; + break; + } + } + + if (condIdxStart != -1) { + // Then we have a conditional rule + List toExprs = Arrays.asList(split).subList(condIdxStart+1, split.length-1); + return parseRule(split[condIdxStart-2], toExprs, allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, condIdxStart-2)); + } + return parseRule(split[split.length-3], split[split.length-1], allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, split.length-3)); } @@ -386,6 +403,10 @@ public static RewriterRule parseRule(String exprFrom, String exprTo, Set(), allowedMultiRefs, allowCombinations, varDefinitions); } + public static RewriterRule parseRule(String exprFrom, List exprsTo, Set allowedMultiRefs, boolean allowCombinations, final RuleContext ctx, String... varDefinitions) { + return parseRule(exprFrom, exprsTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, true, varDefinitions); + } + public static RewriterStatement parse(String expr, final RuleContext ctx, Map dataTypes, String... varDefinitions) { for (String def : varDefinitions) parseDataTypes(def, dataTypes, ctx); @@ -418,6 +439,39 @@ public static RewriterRule parseRule(String exprFrom, String exprTo, final RuleC return new RewriterRuleBuilder(ctx).completeRule(parsedFrom, parsedTo).withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations).setUnidirectional(true).build(); } + public static RewriterRule parseRule(String exprFrom, List exprsTo, final RuleContext ctx, Map dataTypes, Set allowedMultiRefs, boolean allowCombinations, boolean asConditional, String... varDefinitions) { + if (!asConditional && exprsTo.size() > 1) + throw new IllegalArgumentException(); + + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + HashMap mmap = new HashMap<>(); + + RewriterStatement parsedFrom = parseExpression(exprFrom, mmap, dataTypes, ctx); + if (ctx.metaPropagator != null) { + parsedFrom = ctx.metaPropagator.apply(parsedFrom); + } + + List parsedTos = new ArrayList<>(); + for (String exprTo : exprsTo) { + RewriterStatement parsedTo = parseExpression(exprTo, mmap, dataTypes, ctx); + + if (ctx.metaPropagator != null) { + parsedTo = ctx.metaPropagator.apply(parsedTo); + parsedTo.prepareForHashing(); + parsedTo.recomputeHashCodes(ctx); + } + + parsedTos.add(parsedTo); + } + + return new RewriterRuleBuilder(ctx) + .completeConditionalRule(parsedFrom, parsedTos) + .withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations) + .setUnidirectional(true).build(); + } + /** * Parses an expression * @param expr the expression string. Note that all whitespaces have to already be removed diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java index b54cea97def..8cb5bab2f54 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java @@ -127,4 +127,20 @@ public void test3() { assert serialized.equals(newSerialized); } + + @Test + public void test4() { + String ruleStr1 = "MATRIX:W1_rand,tmp29911\n" + + "FLOAT:tmp65095\n" + + "\n" + + "*(tmp65095,%*%(W1_rand,t(tmp29911)))\n" + + "=>\n" + + "{\n" + + "t(%*%(*(tmp65095,tmp29911),t(W1_rand)))\n" + + "%*%(*(tmp65095,W1_rand),t(tmp29911))\n" + + "*(tmp65095,t(%*%(tmp29911,t(W1_rand))))\n" + + "}"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + System.out.println(rule1.toString()); + } }