Skip to content

Commit

Permalink
Multirule parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 53e5853 commit 549d456
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,10 @@ private RewriterStatement applyInplace(RewriterStatement.MatchingSubexpression m

public String toString() {
if (isUnidirectional())
return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx);
if (isConditionalMultiRule())
return fromRoot.toParsableString(ctx) + " => {" + toRoots.stream().map(stmt -> stmt.toParsableString(ctx)).collect(Collectors.joining("; ")) + "}";
else
return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx);
else
return fromRoot.toParsableString(ctx) + " <=> " + toRoot.toParsableString(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ public static RewriterRule parseRule(String expr, final RuleContext ctx) {
Set<Integer> allowedMultiRefs = Collections.emptySet();
boolean allowCombinations = false;
boolean parsedExtendedHeader = false;

if (split[0].startsWith("AllowedMultiRefs:")) {
split[0] = split[0].substring(17);
String[] sSplit = split[0].split(",");
Expand All @@ -375,6 +376,22 @@ public static RewriterRule parseRule(String expr, final RuleContext ctx) {
allowCombinations = Boolean.parseBoolean(split[1]);
parsedExtendedHeader = true;
}

int condIdxStart = -1;
for (int i = 2; i < split.length; i++) {
if (split[i].startsWith("{")) {
// Then we have a conditional rule
condIdxStart = i;
break;
}
}

if (condIdxStart != -1) {
// Then we have a conditional rule
List<String> toExprs = Arrays.asList(split).subList(condIdxStart+1, split.length-1);
return parseRule(split[condIdxStart-2], toExprs, allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, condIdxStart-2));
}

return parseRule(split[split.length-3], split[split.length-1], allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, split.length-3));
}

Expand All @@ -386,6 +403,10 @@ public static RewriterRule parseRule(String exprFrom, String exprTo, Set<Integer
return parseRule(exprFrom, exprTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, varDefinitions);
}

public static RewriterRule parseRule(String exprFrom, List<String> exprsTo, Set<Integer> allowedMultiRefs, boolean allowCombinations, final RuleContext ctx, String... varDefinitions) {
return parseRule(exprFrom, exprsTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, true, varDefinitions);
}

public static RewriterStatement parse(String expr, final RuleContext ctx, Map<String, RewriterStatement> dataTypes, String... varDefinitions) {
for (String def : varDefinitions)
parseDataTypes(def, dataTypes, ctx);
Expand Down Expand Up @@ -418,6 +439,39 @@ public static RewriterRule parseRule(String exprFrom, String exprTo, final RuleC
return new RewriterRuleBuilder(ctx).completeRule(parsedFrom, parsedTo).withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations).setUnidirectional(true).build();
}

public static RewriterRule parseRule(String exprFrom, List<String> exprsTo, final RuleContext ctx, Map<String, RewriterStatement> dataTypes, Set<Integer> allowedMultiRefs, boolean allowCombinations, boolean asConditional, String... varDefinitions) {
if (!asConditional && exprsTo.size() > 1)
throw new IllegalArgumentException();

for (String def : varDefinitions)
parseDataTypes(def, dataTypes, ctx);

HashMap<Integer, RewriterStatement> mmap = new HashMap<>();

RewriterStatement parsedFrom = parseExpression(exprFrom, mmap, dataTypes, ctx);
if (ctx.metaPropagator != null) {
parsedFrom = ctx.metaPropagator.apply(parsedFrom);
}

List<RewriterStatement> parsedTos = new ArrayList<>();
for (String exprTo : exprsTo) {
RewriterStatement parsedTo = parseExpression(exprTo, mmap, dataTypes, ctx);

if (ctx.metaPropagator != null) {
parsedTo = ctx.metaPropagator.apply(parsedTo);
parsedTo.prepareForHashing();
parsedTo.recomputeHashCodes(ctx);
}

parsedTos.add(parsedTo);
}

return new RewriterRuleBuilder(ctx)
.completeConditionalRule(parsedFrom, parsedTos)
.withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations)
.setUnidirectional(true).build();
}

/**
* Parses an expression
* @param expr the expression string. Note that all whitespaces have to already be removed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,20 @@ public void test3() {

assert serialized.equals(newSerialized);
}

@Test
public void test4() {
String ruleStr1 = "MATRIX:W1_rand,tmp29911\n" +
"FLOAT:tmp65095\n" +
"\n" +
"*(tmp65095,%*%(W1_rand,t(tmp29911)))\n" +
"=>\n" +
"{\n" +
"t(%*%(*(tmp65095,tmp29911),t(W1_rand)))\n" +
"%*%(*(tmp65095,W1_rand),t(tmp29911))\n" +
"*(tmp65095,t(%*%(tmp29911,t(W1_rand))))\n" +
"}";
RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx);
System.out.println(rule1.toString());
}
}

0 comments on commit 549d456

Please # to comment.