Skip to content

Commit

Permalink
Some bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 520ff96 commit dd5cad5
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class RewriterAlphabetEncoder {
private static final List<String> MATRIX = List.of("MATRIX");

private static Operand[] instructionAlphabet = new Operand[] {
null,
new Operand("+", 2, ALL_TYPES),
new Operand("-", 2, ALL_TYPES),
new Operand("*", 2, ALL_TYPES),
Expand Down Expand Up @@ -84,7 +85,7 @@ public static List<RewriterStatement> buildAssertionVariations(RewriterStatement
return true;
}, true);

if (interestingLeaves.size() < 2)
if (interestingLeaves.isEmpty())
return List.of(root);

List<RewriterStatement> out = new ArrayList<>();
Expand Down Expand Up @@ -195,6 +196,9 @@ public static List<RewriterStatement> buildVariations(RewriterStatement root, fi
}

public static List<RewriterStatement> buildAllPossibleDAGs(List<Operand> operands, final RuleContext ctx, boolean rename) {
if (operands == null)
return Collections.emptyList();

RewriterAlphabetEncoder.ctx = ctx;

List<RewriterStatement> allStmts = recursivelyFindAllCombinations(operands);
Expand Down Expand Up @@ -272,9 +276,17 @@ private static void forEachSlice(int startIdx, int pos, int maxIdx, int[] slices
public static List<Operand> decodeOrderedStatements(int stmt) {
int[] instructions = fromBaseNNumber(stmt, instructionAlphabet.length);
List<Operand> out = new ArrayList<>(instructions.length);

for (int i = 0; i < instructions.length; i++)
out.add(instructionAlphabet[instructions[i]]);
//System.out.println("StmtIdx: " + stmt);

for (int i = 0; i < instructions.length; i++) {
/*System.out.println("Idx: " + i);
System.out.println("digits[" + i + "]: " + instructions[i]);
System.out.println("As op: " + instructionAlphabet[instructions[i]]);*/
Operand toAdd = instructionAlphabet[instructions[i]];
if (toAdd == null)
return null;
out.add(toAdd);
}

return out;
}
Expand All @@ -284,18 +296,26 @@ public static int[] fromBaseNNumber(int l, int n) {
return new int[0];

// We put 1 as the last bit to signalize end of sequence
int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(l));
/*int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(l));
int maxRepr = 1 << (m - 1);
l = l ^ (1 << m);
int numDigits = (int)(Math.log(maxRepr) / Math.log(n)) + 1;
System.out.println("Bin: " + Integer.toBinaryString(l));
System.out.println("m: " + m);
System.out.println("l: " + l);*/

int numDigits = (int)(Math.log(l) / Math.log(n)) + 1;
int[] digits = new int[numDigits];

for (int i = numDigits - 1; i >= 0; i--) {
//System.out.println(l + " % " + n);
digits[i] = l % n;
l = l / n;
}

/*System.out.println("numDigits: " + numDigits);
System.out.println("digits[0]: " + digits[0]);*/

return digits;
}

Expand All @@ -305,16 +325,16 @@ public static int toBaseNNumber(int[] digits, int n) {

int multiplicator = 1;
int out = 0;
int maxPossible = 0;
//int maxPossible = 0;

for (int i = digits.length - 1; i >= 0; i--) {
out += multiplicator * digits[i];
maxPossible += multiplicator * (n - 1);
//maxPossible += multiplicator * (n - 1);
multiplicator *= n;
}

int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(maxPossible));
out |= (1 << m);
/*int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(maxPossible));
out |= (1 << m);*/

return out;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import org.apache.commons.lang3.mutable.MutableBoolean;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
Expand Down Expand Up @@ -68,7 +70,7 @@ public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFuncti
if (rule != null)
foundRewrite.setValue(true);

for (int i = 0; i < 1000 && rule != null; i++) {
for (int i = 0; i < 500 && rule != null; i++) {
//System.out.println("Pre-apply: " + rule.rule.getName());
/*if (currentStmt.toParsableString(ruleSet.getContext()).equals("%*%(X,[](B,1,ncol(X),1,ncol(B)))"))
System.out.println("test");*/
Expand All @@ -77,11 +79,17 @@ public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFuncti
currentStmt = rule.rule.apply(rule.matches.get(0), currentStmt, rule.forward, false);
//System.out.println("Now: " + currentStmt.toParsableString(ruleSet.getContext()));

if (handler != null && !handler.apply(currentStmt, rule.rule))
//transforms.add(currentStmt.toParsableString(ruleSet.getContext()));

if (handler != null && !handler.apply(currentStmt, rule.rule)) {
rule = null;
break;
}

if (!(currentStmt instanceof RewriterInstruction))
if (!(currentStmt instanceof RewriterInstruction)) {
rule = null;
break;
}

if (accelerated)
rule = ruleSet.acceleratedFindFirst(currentStmt);
Expand All @@ -90,7 +98,7 @@ public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFuncti
}

if (rule != null)
throw new IllegalArgumentException("Expression did not converge:\n" + currentStmt.toParsableString(ruleSet.getContext(), true));
throw new IllegalArgumentException("Expression did not converge:\n" + currentStmt.toParsableString(ruleSet.getContext(), true) + "\nRule: " + rule);

return currentStmt;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ public void testExpressionClustering() {
RewriterStatement canonicalForm = converter.apply(stmt);
computeCost(stmt, ctx);

List<RewriterStatement> equivalentExpressions = new ArrayList<>();
equivalentExpressions.add(stmt);
canonicalForm.unsafePutMeta("equivalentExpressions", equivalentExpressions);

// Insert the canonical form or retrieve the existing entry
RewriterStatement existingEntry = canonicalExprDB.insertOrReturn(ctx, canonicalForm);

if (existingEntry == null) {
List<RewriterStatement> equivalentExpressions = new ArrayList<>();
equivalentExpressions.add(stmt);
canonicalForm.unsafePutMeta("equivalentExpressions", equivalentExpressions);
} else {
List<RewriterStatement> equivalentExpressions = (List<RewriterStatement>) existingEntry.getMeta("equivalentExpressions");
if (existingEntry != null) {
equivalentExpressions = (List<RewriterStatement>) existingEntry.getMeta("equivalentExpressions");
equivalentExpressions.add(stmt);

if (equivalentExpressions.size() == 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -890,4 +890,14 @@ public void testSumEquality5() {

assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
}

@Test
public void testSimpleConvergence() {
RewriterStatement stmt1 = RewriterUtils.parse("sum(a)", ctx, "FLOAT:a");

stmt1 = canonicalConverter.apply(stmt1);

System.out.println("==========");
System.out.println(stmt1.toParsableString(ctx, true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ public void testEncode1() {
@Test
public void testRandomStatementGeneration() {
int ctr = 0;
for (int i = 0; i < 1000; i++) {
for (int i = 1; i < 16; i++) {
List<RewriterAlphabetEncoder.Operand> ops = RewriterAlphabetEncoder.decodeOrderedStatements(i);
System.out.println("Idx: " + i);
System.out.println(ops);
//System.out.println("Idx: " + i);
//System.out.println(ops);
//System.out.println(RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, false).size());

for (RewriterStatement stmt : RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, true)) {
System.out.println("Base: " + stmt.toParsableString(ctx));
for (RewriterStatement sstmt : RewriterAlphabetEncoder.buildAssertionVariations(stmt, ctx, true)) {
canonicalConverter.apply(sstmt);
//System.out.println(sstmt.toParsableString(ctx));
System.out.println(sstmt.toParsableString(ctx));
ctr++;
}
}
Expand Down

0 comments on commit dd5cad5

Please # to comment.