Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent a68d6b6 commit a012b32
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ else if (type.equals("INT"))
if (type.equals("FLOAT"))
return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMinFloat((double)num, stmt);
else if (type.equals("INT"))
return (num, stmt) -> num == null ? stmt.intLiteral() : foldMinInt((long)num, stmt);
return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMinInt((long)num, stmt);
break;
case "max":
if (type.equals("FLOAT"))
return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMaxFloat((double)num, stmt);
else if (type.equals("INT"))
return (num, stmt) -> num == null ? stmt.intLiteral() : foldMaxInt((long)num, stmt);
return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMaxInt((long)num, stmt);
break;
}

Expand Down Expand Up @@ -122,30 +122,30 @@ public static double foldSumFloat(double num, RewriterStatement next) {
}

public static long foldSumInt(long num, RewriterStatement next) {
return num + next.intLiteral();
return num + next.intLiteral(false);
}

public static double foldMulFloat(double num, RewriterStatement next) {
return num * next.floatLiteral();
}

public static long foldMulInt(long num, RewriterStatement next) {
return num * next.intLiteral();
return num * next.intLiteral(false);
}

public static double foldMinFloat(double num, RewriterStatement next) {
return Math.min(num, next.floatLiteral());
}

public static long foldMinInt(long num, RewriterStatement next) {
return Math.min(num, next.intLiteral());
return Math.min(num, next.intLiteral(false));
}

public static double foldMaxFloat(double num, RewriterStatement next) {
return Math.max(num, next.floatLiteral());
}

public static long foldMaxInt(long num, RewriterStatement next) {
return Math.max(num, next.intLiteral());
return Math.max(num, next.intLiteral(false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,13 @@ public Object apply( Object _hi ) {
hi = _applyRewrite89(hi); // +(-(0.0,b),A) => -(A,b)
} else {
if ( hi_1.getDataType() == Types.DataType.MATRIX ) {
System.out.println("HERE0");
if ( hi_1 instanceof UnaryOp ) {
System.out.println("a");
hi = _applyRewrite30(hi); // +(a,cast.MATRIX(0.0)) => cast.MATRIX(a)
hi = _applyRewrite72(hi); // +(a,cast.MATRIX(b)) => cast.MATRIX(+(a,b))
} else if ( hi_1 instanceof BinaryOp ) {
System.out.println("b");
if ( (( BinaryOp ) hi_1 ).getOp() == Types.OpOp2.MINUS ) {
if ( hi_1.getInput().size() == 2 ) {
Hop hi_1_0 = hi_1.getInput(0);
Expand Down Expand Up @@ -214,6 +217,7 @@ public Object apply( Object _hi ) {
}
}
} else if ( hi_1 instanceof ReorgOp ) {
System.out.println("c");
if ( (( ReorgOp ) hi_1 ).getOp() == Types.ReOrgOp.REV ) {
hi = _applyRewrite283(hi); // +(a,rev($1:-(b,C))) => -(+(a,b),rev(C))
hi = _applyRewrite288(hi); // +(a,rev($1:-(C,b))) => +(-(a,b),rev(C))
Expand All @@ -226,6 +230,7 @@ public Object apply( Object _hi ) {
hi = _applyRewrite356(hi); // +(a,t($1:+(C,b))) => +(+(a,b),t(C))
}
} else {
System.out.println("HERE1");
hi = _applyRewrite5(hi); // +(0.0,A) => A
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,15 @@ public Object getLiteral() {
}

@Override
public long intLiteral() {
public long intLiteral(boolean cast) {
if (getLiteral() instanceof Boolean)
return (boolean)getLiteral() ? 1 : 0;

if (cast && getLiteral() instanceof Double) {
double val = floatLiteral();
return (long)val;
}

return (long)getLiteral();
}

Expand All @@ -113,7 +119,7 @@ public boolean boolLiteral() {
return (boolean)getLiteral();
if (getLiteral() instanceof Long)
return (long)getLiteral() == 0L;
return (double)getLiteral() == 0.0;
return (double)getLiteral() == 0.0D;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public RewriterStatement getLiteralStatement() {
}

@Override
public long intLiteral() {
public long intLiteral(boolean cast) {
throw new UnsupportedOperationException();
}

Expand Down
19 changes: 11 additions & 8 deletions src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -274,27 +275,29 @@ public String serialize(final RuleContext ctx) {
return sb.toString();
}

public boolean generateCodeAndTest(boolean optimize, boolean print) {
public Set<RewriterRule> generateCodeAndTest(boolean optimize, boolean print) {
String javaCode = toJavaCode("MGeneratedRewriteClass", optimize, false, true);
Function<Hop, Hop> f = RewriterCodeGen.compile(javaCode, "MGeneratedRewriteClass");

if (f == null)
return false; // Then, the code could not compile
return null; // Then, the code could not compile

int origSize = rules.size();
//int origSize = rules.size();
Set<RewriterRule> removed = new HashSet<>();

for (int i = 0; i < rules.size(); i++) {
if (!RewriterRuleCreator.validateRuleApplicability(rules.get(i), ctx, print, f)) {
System.out.println("Faulty rule: " + rules.get(i));
rules.remove(i);
i--;
removed.add(rules.get(i));
//rules.remove(i);
//i--;
}
}

if (rules.size() != origSize)
accelerate();
//if (rules.size() != origSize)
// accelerate();

return true;
return removed;
}

public static RewriterRuleSet deserialize(String data, final RuleContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,10 @@ public void replace(RewriterStatement newStmt) {
public abstract boolean isLiteral();
public abstract Object getLiteral();
public abstract RewriterStatement getLiteralStatement();
public abstract long intLiteral();
public long intLiteral() {
return intLiteral(false);
}
public abstract long intLiteral(boolean cast);
public abstract double floatLiteral();
public abstract boolean boolLiteral();

Expand Down
Loading

0 comments on commit a012b32

Please # to comment.