Skip to content

Commit

Permalink
Further Code Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 3b07f75 commit c13a1a8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 726 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -495,6 +496,105 @@ public static int toBaseNNumber(int[] digits, int n) {
return out;
}

public static List<RewriterStatement> mergeSubtreeCombinations(RewriterStatement stmt, List<Integer> indices, List<List<RewriterStatement>> mList, final RuleContext ctx, int maximumCombinations) {
if (indices.isEmpty())
return List.of(stmt);

List<RewriterStatement> mergedTreeCombinations = new ArrayList<>();
RewriterUtils.cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> {
RewriterStatement cpy = stmt.copyNode();
for (int i = 0; i < stack.length; i++)
cpy.getOperands().set(indices.get(i), stack[i]);
cpy.consolidate(ctx);
cpy.prepareForHashing();
cpy.recomputeHashCodes(ctx);
mergedTreeCombinations.add(cpy);
return mergedTreeCombinations.size() < maximumCombinations;
});

return mergedTreeCombinations;
}

public static List<RewriterStatement> generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) {
List<RewriterStatement> l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations);

if (ctx.metaPropagator != null)
l.forEach(subtree -> ctx.metaPropagator.apply(subtree));

return l.stream().map(subtree -> {
if (ctx.metaPropagator != null)
subtree = ctx.metaPropagator.apply(subtree);

subtree.prepareForHashing();
subtree.recomputeHashCodes(ctx);
return subtree;
}).collect(Collectors.toList());
}

private static Random rd = new Random();

private static List<RewriterStatement> generateSubtrees(RewriterStatement stmt, Map<RewriterStatement, List<RewriterStatement>> visited, final RuleContext ctx, int maxCombinations) {
if (stmt == null)
return Collections.emptyList();

RewriterStatement is = stmt;
List<RewriterStatement> alreadyVisited = visited.get(is);

if (alreadyVisited != null)
return alreadyVisited;

if (stmt.getOperands().size() == 0)
return List.of(stmt);

// Scan if operand is not a DataType
List<Integer> indices = new ArrayList<>();
for (int i = 0; i < stmt.getOperands().size(); i++) {
if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral())
indices.add(i);
}

int n = indices.size();
int totalSubsets = 1 << n;

List<RewriterStatement> mList = new ArrayList<>();

visited.put(is, mList);

List<List<RewriterStatement>> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList());
List<RewriterStatement> out = new ArrayList<>();

for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) {
List<List<RewriterStatement>> mOptionCpy = new ArrayList<>(mOptions);

for (int i = 0; i < n; i++) {
// Check if the i-th child is included in the current subset
if ((subsetMask & (1 << i)) == 0) {
String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx);
String namePrefix = "tmp";
if (dt.equals("MATRIX"))
namePrefix = "M";
else if (dt.equals("FLOAT"))
namePrefix = "f";
else if (dt.equals("INT"))
namePrefix = "i";
else if (dt.equals("BOOL"))
namePrefix = "b";
RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt);
mT.consolidate(ctx);
mOptionCpy.set(i, List.of(mT));
}
}

out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations));
if (out.size() > maxCombinations) {
System.out.println("Aborting early due to too many combinations");
return out;
}
}

return out;
}

public static final class Operand {
public final String op;
public final int numArgs;
Expand Down
Loading

0 comments on commit c13a1a8

Please # to comment.