diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java index bd9d2f2008b..aedf5e633df 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java @@ -1312,9 +1312,7 @@ public static void pushdownStreamSelections(final List rules, fina }, true) .linkUnidirectional(hooks.get(1).getId(), hooks.get(2).getId(), lnk -> { RewriterStatement.transferMeta(lnk); - System.out.println("HERE"); - // TODO: Big issue when having multiple references to the same sub-dag for (int idx = 0; idx < 2; idx++) { RewriterStatement oldRef = lnk.oldStmt.getChild(idx); @@ -1340,10 +1338,7 @@ public static void pushdownStreamSelections(final List rules, fina } }, true) .apply(hooks.get(3).getId(), stmt -> { - System.out.println("BEFORE: " + stmt.toParsableString(ctx)); stmt.getOperands().set(0, stmt.getChild(0, 2)); - System.out.println("AFTER: " + stmt.toParsableString(ctx)); - System.out.println("Cnt: " + stmt.getChild(0).refCtr); }, true) .build() ); @@ -1379,6 +1374,21 @@ public static void pushdownStreamSelections(final List rules, fina RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, mStmtC).consolidate(ctx); final RewriterStatement newStmt = RewriterUtils.foldConstants(mStmt, ctx); + /*UUID oldRefId = (UUID)oldRef.getMeta("idxId"); + + RewriterStatement newOne = RewriterUtils.replaceReferenceAware(lnk.newStmt.get(0).getChild(2), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + if (idxId.equals(oldRefId)) + return newStmt; + } + + return null; + }); + + if (newOne != null) + lnk.newStmt.get(0).getOperands().set(2, newOne);*/ + // Replace all references to h with lnk.newStmt.get(0).getOperands().get(2).forEachPostOrder((el, pred) -> { for (int i = 0; i < el.getOperands().size(); i++) { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java index c4cb84de082..d5999bae995 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java @@ -829,11 +829,11 @@ public static RewriterStatement replaceReferenceAware(RewriterStatement root, bo RewriterStatement newSub = replaceReferenceAware(root.getOperands().get(i), duplicateReferences, comparer, visited); if (newSub != null) { - System.out.println("NewSub: " + newSub); + //System.out.println("NewSub: " + newSub); if (duplicateReferences && newOne == null) { root = root.copyNode(); newOne = root; - System.out.println("Duplication required: " + root); + //System.out.println("Duplication required: " + root); } root.getOperands().set(i, newSub); @@ -844,6 +844,27 @@ public static RewriterStatement replaceReferenceAware(RewriterStatement root, bo return newOne; } + public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) { + for (int i = 0; i < root.getOperands().size(); i++) { + RewriterStatement child = root.getChild(i); + if (child.isInstruction() && child.refCtr > 1) { + if (!child.trueInstruction().equals("_idx") + && !child.trueInstruction().equals("_m") + && !child.trueInstruction().equals("idxExpr") + //&& !child.trueInstruction().equals("argList") + && !child.trueInstruction().equals("_EClass")) { + RewriterStatement cpy = child.copyNode(); + root.getOperands().set(i, cpy); + child.refCtr--; + cpy.getOperands().forEach(op -> op.refCtr++); + //System.out.println("Copied: " + child.trueInstruction()); + } + } + + unfoldExpressions(child, ctx); + } + } + // Function to check if two lists match public static boolean findMatchingOrderings(List col1, List col2, T[] stack, BiFunction matcher, Function permutationEmitter, boolean symmetric) { if (col1.size() != col2.size()) @@ -1406,6 +1427,8 @@ public static Function buildCanonicalFormC RewriterUtils.mergeArgLists(stmt, ctx); stmt = stmt.getAssertions(ctx).cleanupEClasses(stmt); + unfoldExpressions(stmt, ctx); + stmt.prepareForHashing(); // TODO: After this, stuff like CSE, A-A = 0, etc. must still be applied @@ -1844,7 +1867,15 @@ private static void postCleanupIndexExpr(RewriterStatement cur) { if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) { // Then we just choose the first index - cur.forEachPreOrder(cur2 -> { + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + cur.getChild(2).forEachPreOrder(cur2 -> { for (int i = 0; i < cur2.getOperands().size(); i++) { if (cur2.getChild(i).equals(idx2)) cur2.getOperands().set(i, idx1); diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java index c54d701955e..874edfd3a25 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java @@ -413,8 +413,8 @@ public void testExactMatch() { System.out.println("=========="); System.out.println(stmt2.toParsableString(ctx, true)); - assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); - assert !stmt2.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + assert stmt2.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2)); } @Test