From 31f802939703a933f8e45fc84807e308de84796c Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 29 Oct 2024 15:56:48 +0100 Subject: [PATCH] Some major performance improvements --- .../sysds/hops/rewriter/RewriterDataType.java | 3 +++ .../hops/rewriter/RewriterInstruction.java | 26 +++++++++++++------ .../hops/rewriter/RewriterRuleCollection.java | 7 ++++- .../hops/rewriter/RewriterStatement.java | 1 + .../sysds/hops/rewriter/RewriterUtils.java | 1 + .../rewrite/RewriterClusteringTest.java | 8 ++++++ 6 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java index 40695db7d8f..c1cafd9486d 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java @@ -30,6 +30,9 @@ public String getResultingDataType(final RuleContext ctx) { return type; } + @Override + public void refreshReturnType(final RuleContext ctx) {} + @Override public boolean isLiteral() { return literal != null && !(literal instanceof List); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java index 477a19cffbf..8d25bd9f9a6 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java @@ -18,6 +18,7 @@ public class RewriterInstruction extends RewriterStatement { private String id; + private String returnType; private String instr; //private RewriterDataType result = new RewriterDataType(); private ArrayList operands = new ArrayList<>(); @@ -25,8 +26,6 @@ public class RewriterInstruction extends RewriterStatement { private boolean consolidated = false; private int hashCode; - //private DualHashBidiMap links = null; - @Override public String getId() { return id; @@ -34,10 +33,20 @@ public String getId() { @Override public String getResultingDataType(final RuleContext ctx) { - if (isArgumentList()) { - return getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "..."; - } - return ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx); + if (returnType != null) + return returnType; + + if (isArgumentList()) + returnType = getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "..."; + else + returnType = ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx); + + return returnType; + } + + @Override + public void refreshReturnType(final RuleContext ctx) { + returnType = null; } @Override @@ -472,9 +481,10 @@ public boolean hasProperty(String property, final RuleContext ctx) { } public String trueInstruction() { - Object trueInstrObj = getMeta("trueInstr"); + // Legacy code + /*Object trueInstrObj = getMeta("trueInstr"); if (trueInstrObj != null && trueInstrObj instanceof String) - return (String)trueInstrObj; + return (String)trueInstrObj;*/ return instr; } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java index da5ae1e2d58..301ce8a9714 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java @@ -1179,7 +1179,7 @@ public static void canonicalExpandAfterFlattening(final List rules .parseGlobalVars("INT...:indices") .parseGlobalVars("FLOAT...:ops") .withParsedStatement("sum($1:_idxExpr(indices, +(ops)))", hooks) - .toParsedStatement("+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed + .toParsedStatement("$4:+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) .apply(hooks.get(3).getId(), newArgList -> { RewriterStatement oldArgList = newArgList.getChild(0, 0, 1, 0); @@ -1196,6 +1196,11 @@ public static void canonicalExpandAfterFlattening(final List rules RewriterUtils.copyIndexList(newIdxExpr); newArgList.getOperands().add(newSum); } + + newArgList.refreshReturnType(ctx); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + stmt.refreshReturnType(ctx); }, true) .build() ); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java index b73ae7185b6..38ae6c05a9e 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -296,6 +296,7 @@ public void setLiteral(Object literal) { public abstract RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx); // Returns the new maxRefId abstract int toParsableString(StringBuilder builder, Map refs, int maxRefId, Map> vars, final RuleContext ctx); + abstract void refreshReturnType(final RuleContext ctx); public String toParsableString(final RuleContext ctx, boolean includeDefinitions) { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java index c9633e67aaf..7c34eb334d9 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java @@ -279,6 +279,7 @@ public static void mergeArgLists(RewriterStatement stmt, final RuleContext ctx) stmt.forEachPreOrder(el -> { tryFlattenNestedArgList(ctx, el, el, -1); tryFlattenNestedOperatorPatterns(ctx, el); + el.refreshReturnType(ctx); return true; }); diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java index 18e37a27f43..9dbeb040ad1 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java @@ -1,5 +1,6 @@ package org.apache.sysds.test.component.codegen.rewrite; +import org.apache.commons.lang3.mutable.MutableInt; import org.apache.commons.lang3.mutable.MutableLong; import org.apache.sysds.hops.rewriter.RewriterDatabase; import org.apache.sysds.hops.rewriter.RewriterHeuristic; @@ -61,9 +62,16 @@ public void testExpressionClustering() { RewriterDatabase canonicalExprDB = new RewriterDatabase(); List foundEquivalences = new ArrayList<>(); + int size = db.size(); + MutableInt ctr = new MutableInt(0); + db.forEach(expr -> { + if (ctr.incrementAndGet() % 10 == 0) + System.out.println("Done: " + ctr.intValue() + " / " + size); // First, build all possible subtrees List subExprs = RewriterUtils.generateSubtrees(expr, ctx); + if (subExprs.size() > 100) + System.out.println("Critical number of subtrees: " + subExprs.size()); //List subExprs = List.of(expr); long evaluationCtr = 0;