Skip to content

Commit

Permalink
Some major performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaybit0 committed Jan 29, 2025
1 parent 0e15664 commit 31f8029
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public String getResultingDataType(final RuleContext ctx) {
return type;
}

@Override
public void refreshReturnType(final RuleContext ctx) {}

@Override
public boolean isLiteral() {
return literal != null && !(literal instanceof List<?>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,35 @@
public class RewriterInstruction extends RewriterStatement {

private String id;
private String returnType;
private String instr;
//private RewriterDataType result = new RewriterDataType();
private ArrayList<RewriterStatement> operands = new ArrayList<>();
private Function<List<RewriterStatement>, Long> costFunction = null;
private boolean consolidated = false;
private int hashCode;

//private DualHashBidiMap<RewriterStatement, RewriterStatement> links = null;

@Override
public String getId() {
return id;
}

@Override
public String getResultingDataType(final RuleContext ctx) {
if (isArgumentList()) {
return getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "...";
}
return ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx);
if (returnType != null)
return returnType;

if (isArgumentList())
returnType = getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "...";
else
returnType = ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx);

return returnType;
}

@Override
public void refreshReturnType(final RuleContext ctx) {
returnType = null;
}

@Override
Expand Down Expand Up @@ -472,9 +481,10 @@ public boolean hasProperty(String property, final RuleContext ctx) {
}

public String trueInstruction() {
Object trueInstrObj = getMeta("trueInstr");
// Legacy code
/*Object trueInstrObj = getMeta("trueInstr");
if (trueInstrObj != null && trueInstrObj instanceof String)
return (String)trueInstrObj;
return (String)trueInstrObj;*/
return instr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ public static void canonicalExpandAfterFlattening(final List<RewriterRule> rules
.parseGlobalVars("INT...:indices")
.parseGlobalVars("FLOAT...:ops")
.withParsedStatement("sum($1:_idxExpr(indices, +(ops)))", hooks)
.toParsedStatement("+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed
.toParsedStatement("$4:+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
.apply(hooks.get(3).getId(), newArgList -> {
RewriterStatement oldArgList = newArgList.getChild(0, 0, 1, 0);
Expand All @@ -1196,6 +1196,11 @@ public static void canonicalExpandAfterFlattening(final List<RewriterRule> rules
RewriterUtils.copyIndexList(newIdxExpr);
newArgList.getOperands().add(newSum);
}

newArgList.refreshReturnType(ctx);
}, true)
.apply(hooks.get(4).getId(), stmt -> {
stmt.refreshReturnType(ctx);
}, true)
.build()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ public void setLiteral(Object literal) {
public abstract RewriterStatement nestedCopyOrInject(Map<RewriterStatement, RewriterStatement> copiedObjects, TriFunction<RewriterStatement, RewriterStatement, Integer, RewriterStatement> injector, RewriterStatement parent, int pIdx);
// Returns the new maxRefId
abstract int toParsableString(StringBuilder builder, Map<RewriterRule.IdentityRewriterStatement, Integer> refs, int maxRefId, Map<String, Set<String>> vars, final RuleContext ctx);
abstract void refreshReturnType(final RuleContext ctx);

public String toParsableString(final RuleContext ctx, boolean includeDefinitions) {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ public static void mergeArgLists(RewriterStatement stmt, final RuleContext ctx)
stmt.forEachPreOrder(el -> {
tryFlattenNestedArgList(ctx, el, el, -1);
tryFlattenNestedOperatorPatterns(ctx, el);
el.refreshReturnType(ctx);
return true;
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.apache.sysds.test.component.codegen.rewrite;

import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.mutable.MutableLong;
import org.apache.sysds.hops.rewriter.RewriterDatabase;
import org.apache.sysds.hops.rewriter.RewriterHeuristic;
Expand Down Expand Up @@ -61,9 +62,16 @@ public void testExpressionClustering() {
RewriterDatabase canonicalExprDB = new RewriterDatabase();
List<RewriterStatement> foundEquivalences = new ArrayList<>();

int size = db.size();
MutableInt ctr = new MutableInt(0);

db.forEach(expr -> {
if (ctr.incrementAndGet() % 10 == 0)
System.out.println("Done: " + ctr.intValue() + " / " + size);
// First, build all possible subtrees
List<RewriterStatement> subExprs = RewriterUtils.generateSubtrees(expr, ctx);
if (subExprs.size() > 100)
System.out.println("Critical number of subtrees: " + subExprs.size());
//List<RewriterStatement> subExprs = List.of(expr);
long evaluationCtr = 0;

Expand Down

0 comments on commit 31f8029

Please # to comment.