Skip to content

Commit

Permalink
Some more improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 44c85c3 commit 60436f3
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ public List<RewriterRule> createNonGenericRules(Map<String, Set<String>> funcMap
static class IdentityRewriterStatement {
public RewriterStatement stmt;

@Deprecated
public IdentityRewriterStatement(RewriterStatement stmt) {
this.stmt = stmt;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ private static List<RewriterStatement> generateSubtrees(RewriterStatement stmt,
// Scan if operand is not a DataType
List<Integer> indices = new ArrayList<>();
for (int i = 0; i < stmt.getOperands().size(); i++) {
if (stmt.getOperands().get(i).isInstruction() || stmt.isLiteral())
if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral())
indices.add(i);
}

Expand All @@ -1165,6 +1165,8 @@ private static List<RewriterStatement> generateSubtrees(RewriterStatement stmt,

List<List<RewriterStatement>> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList());
List<RewriterStatement> out = new ArrayList<>();
//System.out.println("Stmt: " + stmt.toParsableString(ctx));
//System.out.println("mOptions: " + mOptions);

for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) {
List<List<RewriterStatement>> mOptionCpy = new ArrayList<>(mOptions);
Expand All @@ -1178,6 +1180,7 @@ private static List<RewriterStatement> generateSubtrees(RewriterStatement stmt,
}
}

//System.out.println("mOptionCopy: " + mOptionCpy);
out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations));
if (out.size() > maxCombinations) {
System.out.println("Aborting early due to too many combinations");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void testExpressionClustering() {
db.parForEach(expr -> {
if (ctr.incrementAndGet() % 10 == 0)
System.out.println("Done: " + ctr.intValue() + " / " + size);
if (ctr.intValue() > 2000)
if (ctr.intValue() > 10000)
return; // Skip
// First, build all possible subtrees
//System.out.println("Eval:\n" + expr.toParsableString(ctx, true));
Expand Down Expand Up @@ -140,7 +140,7 @@ public void testExpressionClustering() {
totalCanonicalizationMillis.addAndGet(mCanonicalizationMillis);
});

printEquivalences(foundEquivalences, System.currentTimeMillis() - startTime, generatedExpressions.longValue(), evaluatedExpressions.longValue(), totalCanonicalizationMillis.longValue(), failures.longValue(), true);
printEquivalences(/*foundEquivalences*/ Collections.emptyList(), System.currentTimeMillis() - startTime, generatedExpressions.longValue(), evaluatedExpressions.longValue(), totalCanonicalizationMillis.longValue(), failures.longValue(), true);

System.out.println("===== SUGGESTED REWRITES =====");
List<Tuple5<Double, Long, Long, RewriterStatement, RewriterStatement>> rewrites = findSuggestedRewrites(foundEquivalences);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.apache.sysds.test.component.codegen.rewrite.functions;

import org.apache.sysds.hops.rewriter.RewriterStatement;
import org.apache.sysds.hops.rewriter.RewriterUtils;
import org.apache.sysds.hops.rewriter.RuleContext;
import org.junit.BeforeClass;
import org.junit.Test;

import java.util.List;
import java.util.function.Function;

public class SubtreeGeneratorTest {

private static RuleContext ctx;
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;

@BeforeClass
public static void setup() {
ctx = RewriterUtils.buildDefaultContext();
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true);
}

@Test
public void test1() {
RewriterStatement stmt = RewriterUtils.parse("+(1, a)", ctx, "LITERAL_INT:1", "FLOAT:a");
List<RewriterStatement> subtrees = RewriterUtils.generateSubtrees(stmt, ctx, 100);

for (RewriterStatement sub : subtrees) {
System.out.println("==========");
System.out.println(sub.toParsableString(ctx, true));
}

assert subtrees.size() == 2;
}

@Test
public void test2() {
RewriterStatement stmt = RewriterUtils.parse("+(+(1, b), a)", ctx, "LITERAL_INT:1", "FLOAT:a,b");
List<RewriterStatement> subtrees = RewriterUtils.generateSubtrees(stmt, ctx, 100);

for (RewriterStatement sub : subtrees) {
System.out.println("==========");
System.out.println(sub.toParsableString(ctx, true));
}

assert subtrees.size() == 3;
}
}

0 comments on commit 60436f3

Please # to comment.