From c13a1a815c88a1c23649b6aadc0acf80a8111ebc Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 28 Jan 2025 16:10:26 +0100 Subject: [PATCH] Further Code Cleanup --- .../rewriter/utils/RewriterSearchUtils.java | 100 +++ .../hops/rewriter/utils/RewriterUtils.java | 733 +----------------- .../rewrite/RewriterClusteringTest.java | 2 +- .../functions/SubtreeGeneratorTest.java | 7 +- 4 files changed, 116 insertions(+), 726 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java index 1af9888b056..09bd3e005fb 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java @@ -30,6 +30,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; @@ -495,6 +496,105 @@ public static int toBaseNNumber(int[] digits, int n) { return out; } + public static List mergeSubtreeCombinations(RewriterStatement stmt, List indices, List> mList, final RuleContext ctx, int maximumCombinations) { + if (indices.isEmpty()) + return List.of(stmt); + + List mergedTreeCombinations = new ArrayList<>(); + RewriterUtils.cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> { + RewriterStatement cpy = stmt.copyNode(); + for (int i = 0; i < stack.length; i++) + cpy.getOperands().set(indices.get(i), stack[i]); + cpy.consolidate(ctx); + cpy.prepareForHashing(); + cpy.recomputeHashCodes(ctx); + mergedTreeCombinations.add(cpy); + return mergedTreeCombinations.size() < maximumCombinations; + }); + + return mergedTreeCombinations; + } + + public static List generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) { + List l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations); + + if (ctx.metaPropagator != null) + l.forEach(subtree -> ctx.metaPropagator.apply(subtree)); + + return l.stream().map(subtree -> { + if (ctx.metaPropagator != null) + subtree = ctx.metaPropagator.apply(subtree); + + subtree.prepareForHashing(); + subtree.recomputeHashCodes(ctx); + return subtree; + }).collect(Collectors.toList()); + } + + private static Random rd = new Random(); + + private static List generateSubtrees(RewriterStatement stmt, Map> visited, final RuleContext ctx, int maxCombinations) { + if (stmt == null) + return Collections.emptyList(); + + RewriterStatement is = stmt; + List alreadyVisited = visited.get(is); + + if (alreadyVisited != null) + return alreadyVisited; + + if (stmt.getOperands().size() == 0) + return List.of(stmt); + + // Scan if operand is not a DataType + List indices = new ArrayList<>(); + for (int i = 0; i < stmt.getOperands().size(); i++) { + if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral()) + indices.add(i); + } + + int n = indices.size(); + int totalSubsets = 1 << n; + + List mList = new ArrayList<>(); + + visited.put(is, mList); + + List> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList()); + List out = new ArrayList<>(); + + for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) { + List> mOptionCpy = new ArrayList<>(mOptions); + + for (int i = 0; i < n; i++) { + // Check if the i-th child is included in the current subset + if ((subsetMask & (1 << i)) == 0) { + String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx); + String namePrefix = "tmp"; + if (dt.equals("MATRIX")) + namePrefix = "M"; + else if (dt.equals("FLOAT")) + namePrefix = "f"; + else if (dt.equals("INT")) + namePrefix = "i"; + else if (dt.equals("BOOL")) + namePrefix = "b"; + RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt); + mT.consolidate(ctx); + mOptionCpy.set(i, List.of(mT)); + } + } + + out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations)); + if (out.size() > maxCombinations) { + System.out.println("Aborting early due to too many combinations"); + return out; + } + } + + return out; + } + public static final class Operand { public final String op; public final int numArgs; diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java index b5c95567678..4851ba3de53 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java @@ -67,233 +67,19 @@ public static String typedToUntypedInstruction(String instr) { return instr.substring(0, instr.indexOf('(')); } - public static Function propertyExtractor(final List desiredProperties, final RuleContext ctx) { - return el -> { - if (el instanceof RewriterInstruction) { - Set properties = ((RewriterInstruction) el).getProperties(ctx); - String trueInstr = el.trueTypedInstruction(ctx); - for (String desiredProperty : desiredProperties) { - if (trueInstr.equals(desiredProperty) || (properties != null && properties.contains(desiredProperty))) { - System.out.println("Found property: " + desiredProperty + " (for " + el + ")"); - String oldInstr = ((RewriterInstruction) el).changeConsolidatedInstruction(desiredProperty, ctx); - if (el.getMeta("trueInstr") == null) { - el.unsafePutMeta("trueInstr", oldInstr); - el.unsafePutMeta("trueName", oldInstr); - } - break; - } - } - } - return true; - }; - } - public static BiFunction binaryStringRepr(String op) { return (stmt, ctx) -> { - List operands = ((RewriterInstruction)stmt).getOperands(); + List operands = stmt.getOperands(); String op1Str = operands.get(0).toString(ctx); - if (operands.get(0) instanceof RewriterInstruction && ((RewriterInstruction)operands.get(0)).getOperands().size() > 1) + if (operands.get(0) instanceof RewriterInstruction && operands.get(0).getOperands().size() > 1) op1Str = "(" + op1Str + ")"; String op2Str = operands.get(1).toString(ctx); - if (operands.get(1) instanceof RewriterInstruction && ((RewriterInstruction)operands.get(1)).getOperands().size() > 1) + if (operands.get(1) instanceof RewriterInstruction && operands.get(1).getOperands().size() > 1) op2Str = "(" + op2Str + ")"; return op1Str + op + op2Str; }; } - public static BiFunction wrappedBinaryStringRepr(String op) { - return (stmt, ctx) -> { - List operands = ((RewriterInstruction)stmt).getOperands(); - return "(" + operands.get(0).toString(ctx) + ")" + op + "(" + operands.get(1).toString(ctx) + ")"; - }; - } - - // No longer maintained - @Deprecated - public static RewriterStatement buildFusedPlan(RewriterStatement origStatement, final RuleContext ctx) { - RewriterStatement cpy = origStatement.nestedCopy(true); - MutableObject mCpy = new MutableObject<>(cpy); - - Map, List> mmap = eraseAccessTypes(mCpy, ctx); - cpy = mCpy.getValue(); - - // Identify common element wise accesses (e.g. A[i, j] + B[i, j] for all i, j) - //Map, List> mmap = new HashMap<>(); - - /*for (Tuple3 mTuple : mSet.keySet()) { - List accesses = mmap.compute(new Tuple2<>(mTuple._2(), mTuple._3()), (k, v) -> v == null ? new ArrayList<>() : v); - accesses.add(mTuple._1().stmt); - } - - List fuseList = new ArrayList<>(); - - MutableObject mParent = new MutableObject<>(cpy); - - cpy.forEachPreOrder((current, parent, pIdx) -> { - if (!current.isInstruction()) - return true; - - if (current.trueInstruction().equals("_m")) { - if (parent != null) - parent.getOperands().set(pIdx, current.getOperands().get(2)); - else - mParent.setValue(current.getOperands().get(2)); - } - - return true; - }); - - mmap.forEach((k, v) -> { - HashMap args = new HashMap<>(); - args.put("idx1", k._1.stmt); - args.put("idx2", k._2.stmt); - args.put("valueFn", ); - RewriterStatement vFn = cpy. - RewriterStatement newStmt = parse("_accessNary(idx1, idx2, valueFn)", ctx, ); - fuseList.add(); - });*/ - - // - - if (mmap.size() == 1) { - Map.Entry, List> entry = mmap.entrySet().iterator().next(); - HashMap args = new HashMap<>(); - - RewriterStatement mS = null; - - if (cpy.isInstruction()) { - if (cpy.trueInstruction().equals("_m")) { - args.put("stmt", cpy.getOperands().get(2)); - args.put("first", entry.getValue().get(0)); - - mS = RewriterUtils.parse("_map(argList(first), stmt)", ctx, args); - mS.getOperands().get(0).getOperands().addAll(entry.getValue().subList(1, entry.getValue().size())); - } else if (cpy.trueInstruction().equals("sum")) { - args.put("stmt", cpy.getOperands().get(0)); - args.put("first", entry.getValue().get(0)); - - System.out.println(args.get("stmt")); - mS = RewriterUtils.parse("_reduce(argList(first), +(_cur(), stmt))", ctx, args); - mS.getOperands().get(0).getOperands().addAll(entry.getValue().subList(1, entry.getValue().size())); - } - } - - return mS; - } - - return null; - } - - public static Map, List> eraseAccessTypes(MutableObject stmt, final RuleContext ctx) { - //Map, RewriterStatement> out = new HashMap<>(); - - Map, List> rewrites = new HashMap<>(); - - HashMap hooks = new HashMap<>(); - - List rules = new ArrayList<>(); - - rules.add(new RewriterRuleBuilder(ctx) - .setUnidirectional(true) - .parseGlobalVars("MATRIX:A") - .parseGlobalVars("INT:i,j") - .parseGlobalVars("FLOAT:v") - .withParsedStatement("[](A, i, j)") - .toParsedStatement("$1:_v(A)", hooks) - .iff(match -> { - List ops = match.getMatchRoot().getOperands(); - return (ops.get(0).isInstruction() && ops.get(0).trueInstruction().equals("_idx")) - || (ops.get(1).isInstruction() && ops.get(1).trueInstruction().equals("_idx")); - }, true) - .apply(hooks.get(1).getId(), (t, m) -> { - t.unsafePutMeta("data", m.getMatchRoot().getOperands().get(0)); - t.unsafePutMeta("idx1", m.getMatchRoot().getOperands().get(1)); - t.unsafePutMeta("idx2", m.getMatchRoot().getOperands().get(2)); - - RewriterStatement idx1 = m.getMatchRoot().getOperands().get(1); - RewriterStatement idx2 = m.getMatchRoot().getOperands().get(2); - Tuple2 mT = new Tuple2<>(idx1, idx2); - - List r = rewrites.get(mT); - - if (r == null) { - r = new ArrayList<>(); - rewrites.put(mT, r); - } - - r.add(t); - }, true) - .build()); - - rules.add(new RewriterRuleBuilder(ctx) - .setUnidirectional(true) - .parseGlobalVars("MATRIX:A") - .parseGlobalVars("INT...:i") - .parseGlobalVars("FLOAT:v") - .withParsedStatement("_idxExpr(i, v)") - .toParsedStatement("$1:v", hooks) - .iff(match -> { - List ops = match.getMatchRoot().getOperands().get(0).getOperands(); - return ops.stream().anyMatch(op -> op.isInstruction() && op.trueInstruction().equals("_idx")); - }, true) - .build()); - - rules.add(new RewriterRuleBuilder(ctx) - .setUnidirectional(true) - .parseGlobalVars("MATRIX:A") - .parseGlobalVars("INT...:i,j") - .parseGlobalVars("FLOAT*:v") - .withParsedStatement("_idxExpr(i, v)") - .toParsedStatement("v", hooks) - .iff(match -> { - List ops = match.getMatchRoot().getOperands().get(0).getOperands(); - return ops.stream().anyMatch(op -> op.isInstruction() && op.trueInstruction().equals("_idx")); - }, true) - .build()); - - RewriterRuleSet rs = new RewriterRuleSet(ctx, rules); - RewriterHeuristic heur = new RewriterHeuristic(rs, true); - - stmt.setValue(heur.apply(stmt.getValue())); - - return rewrites; - - /*stmt.getValue().forEachPostOrder((current, parent, pIdx) -> { - if (!current.isInstruction()) - return; - - if (current.trueInstruction().equals("[]")) { - boolean hasIndex = false; - if (current.getOperands().get(1).isInstruction() && current.getOperands().get(1).trueInstruction().equals("_idx")) - hasIndex = true; - - if (current.getOperands().get(2).isInstruction() && current.getOperands().get(2).trueInstruction().equals("_idx")) - hasIndex = true; - - if (hasIndex) { - current.getOperands().get(0).unsafePutMeta("idx1", current.getOperands().get(1)); - current.getOperands().get(0).unsafePutMeta("idx2", current.getOperands().get(2)); - out.put(new Tuple3<>(new RewriterRule.IdentityRewriterStatement(current.getOperands().get(0)), - new RewriterRule.IdentityRewriterStatement(current.getOperands().get(1)), - new RewriterRule.IdentityRewriterStatement(current.getOperands().get(2))), - current.getOperands().get(0)); - - if (parent != null) - parent.getOperands().set(pIdx, current.getOperands().get(0)); - else - stmt.setValue(current.getOperands().get(0)); - } - } else if (current.trueInstruction().equals("idxExpr")) { - if (parent != null) - parent.getOperands().set(pIdx, current.getOperands().get(1)); - else - stmt.setValue(current.getOperands().get(1)); - } - }); - - return out;*/ - } - public static void mergeArgLists(RewriterStatement stmt, final RuleContext ctx) { stmt.forEachPreOrder(el -> { @@ -488,7 +274,7 @@ public static RewriterRule parseRule(String exprFrom, List exprsTo, fina /** * Parses an expression - * @param expr the expression string. Note that all whitespaces have to already be removed + * @param expr the expression string * @param refmap test * @param dataTypes data type * @param ctx context @@ -524,7 +310,6 @@ private static RewriterStatement doParseExpression(MutableObject mexpr, throw new IllegalArgumentException("Variable '$" + n + "' does not exist!"); return var; - //throw new IllegalArgumentException("Expected the token ':'"); } String remainder = expr.substring(matcher.end() + 1); mexpr.setValue(remainder); @@ -539,7 +324,7 @@ private static RewriterStatement doParseExpression(MutableObject mexpr, } } - public static boolean parseDataTypes(String expr, Map dataTypes, /*List> matrixTypes,*/ final RuleContext ctx) { + public static boolean parseDataTypes(String expr, Map dataTypes, final RuleContext ctx) { RuleContext.currentContext = ctx; Pattern pattern = Pattern.compile("([A-Za-z0-9]|_|\\.|\\*|\\?)([A-Za-z0-9]|_|\\.|\\*|-)*"); Matcher matcher = pattern.matcher(expr); @@ -578,21 +363,7 @@ public static boolean parseDataTypes(String expr, Map dt = new RewriterDataType().as(varName).ofType("BOOL").asLiteral(Boolean.parseBoolean(varName)); } else if (floatLiteral) { dt = new RewriterDataType().as(varName).ofType("FLOAT").asLiteral(Double.parseDouble(varName)); - } /*else if (dType.equals("MATRIX")) { - // TODO - int matType = 0; - if (varName.startsWith("rowVec.")) { - matType = 1; - varName = varName.substring(7); - } else if (varName.startsWith("colVec.")) { - matType = 2; - varName = varName.substring(7); - } - - dt = new RewriterDataType().as(varName).ofType(dType); - - //matrixModes.add(new Tuple2<>(dt, matType)); - }*/ else { + } else { dt = new RewriterDataType().as(varName).ofType(dType); } @@ -618,8 +389,6 @@ private static RewriterStatement parseRawExpression(MutableObject mexpr, Pattern pattern = Pattern.compile("^[^(),:]+"); Matcher matcher = pattern.matcher(expr); - - if (matcher.find()) { String token = matcher.group(); String remainder = expr.substring(matcher.end()); @@ -703,26 +472,6 @@ private static void handleSpecialInstructions(RewriterInstruction instr) { } } - public static HashMap> createIndex(RewriterStatement stmt, final RuleContext ctx) { - HashMap> index = new HashMap<>(); - stmt.forEachPreOrderWithDuplicates(mstmt -> { - if (mstmt instanceof RewriterInstruction) { - RewriterInstruction instr = (RewriterInstruction)mstmt; - index.compute(instr.trueTypedInstruction(ctx), (k, v) -> { - if (v == null) { - return List.of(mstmt); - } else { - if (v.stream().noneMatch(el -> el == instr)) - v.add(instr); - return v; - } - }); - } - return true; - }); - return index; - } - public static void buildBinaryAlgebraInstructions(StringBuilder sb, String instr, List instructions) { for (String arg1 : instructions) { for (String arg2 : instructions) { @@ -738,14 +487,6 @@ else if (arg1.equals("FLOAT") || arg2.equals("FLOAT")) } } - public static void buildBinaryBoolInstructions(StringBuilder sb, String instr, List instructions) { - for (String arg1 : instructions) { - for (String arg2 : instructions) { - sb.append(instr + "(" + arg1 + "," + arg2 + ")::BOOL\n"); - } - } - } - public static void buildTernaryPermutations(List args, TriConsumer func) { buildBinaryPermutations(args, (t1, t2) -> args.forEach(t3 -> func.accept(t1, t2, t3))); } @@ -828,27 +569,7 @@ public static void putAsDefaultBinaryPrintable(List instrs, List putAsBinaryPrintable(instr, types, funcs, binaryStringRepr(" " + instr + " ")); } - public static HashMap> mapToImplementedFunctions(final RuleContext ctx) { - HashMap> out = new HashMap<>(); - Set superTypes = new HashSet<>(); - - for (Map.Entry entry : ctx.instrTypes.entrySet()) { - Set props = ctx.instrProperties.get(entry.getKey()); - if (props != null && !props.isEmpty()) { - for (String prop : props) { - Set impl = out.computeIfAbsent(prop, k -> new HashSet<>()); - impl.add(typedToUntypedInstruction(entry.getKey())); - superTypes.add(typedToUntypedInstruction(prop)); - } - } - } - - for (Map.Entry> entry : out.entrySet()) - entry.getValue().removeAll(superTypes); - - return out; - } - + // Updates the references (including metadata UUIDs) for a copied _idxExpr(args(_idx(...),...),...) public static void copyIndexList(RewriterStatement idxExprRoot) { if (!idxExprRoot.isInstruction() || !idxExprRoot.trueInstruction().equals("_idxExpr")) throw new IllegalArgumentException(); @@ -887,22 +608,11 @@ public static void copyIndexList(RewriterStatement idxExprRoot) { idxExprRoot.getOperands().set(1, out); } - public static void retargetIndexExpressions(RewriterStatement rootExpr, UUID oldIdxId, RewriterStatement newStatement) { - RewriterUtils.replaceReferenceAware(rootExpr, stmt -> { - UUID idxId = (UUID) stmt.getMeta("idxId"); - if (idxId != null) { - if (idxId.equals(oldIdxId)) - return newStatement; - } - - return null; - }); - } - public static RewriterStatement replaceReferenceAware(RewriterStatement root, Function comparer) { return replaceReferenceAware(root, false, comparer, new HashMap<>()); } + // Replaces elements in a DAG. If a parent item has multiple references, the entire path is updated public static RewriterStatement replaceReferenceAware(RewriterStatement root, boolean duplicateReferences, Function comparer, HashMap visited) { if (visited.containsKey(root)) return visited.get(root); @@ -922,11 +632,9 @@ public static RewriterStatement replaceReferenceAware(RewriterStatement root, bo RewriterStatement newSub = replaceReferenceAware(root.getOperands().get(i), duplicateReferences, comparer, visited); if (newSub != null) { - //System.out.println("NewSub: " + newSub); if (duplicateReferences && newOne == null) { root = root.copyNode(); newOne = root; - //System.out.println("Duplication required: " + root); } root.getOperands().set(i, newSub); @@ -937,6 +645,7 @@ public static RewriterStatement replaceReferenceAware(RewriterStatement root, bo return newOne; } + // Deduplicates the DAG (removes duplicate references with new nodes except for leaf data-types) public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) { for (int i = 0; i < root.getOperands().size(); i++) { RewriterStatement child = root.getChild(i); @@ -945,13 +654,11 @@ public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) { && !child.trueInstruction().equals("_m") && !child.trueInstruction().equals("idxExpr") && !child.trueInstruction().equals("rand") - //&& !child.trueInstruction().equals("argList") && !child.trueInstruction().equals("_EClass")) { RewriterStatement cpy = child.copyNode(); root.getOperands().set(i, cpy); child.refCtr--; cpy.getOperands().forEach(op -> op.refCtr++); - //System.out.println("Copied: " + child.trueInstruction()); } } @@ -959,64 +666,6 @@ public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) { } } - // Function to check if two lists match - public static boolean findMatchingOrderings(List col1, List col2, T[] stack, BiFunction matcher, Function permutationEmitter, boolean symmetric) { - if (col1.size() != col2.size()) - return false; // Sizes must match - - if (stack.length < col2.size()) - throw new IllegalArgumentException("Mismatching stack sizes!"); - - if (col1.size() == 1) { - if (matcher.apply(col1.get(0), col2.get(0))) { - stack[0] = col2.get(0); - permutationEmitter.apply(stack); - return true; - } - - return false; - } - - // We need to get true on the diagonal for it to be a valid permutation - List> possiblePermutations = new ArrayList<>(Collections.nCopies(col1.size(), null)); - - boolean anyMatch; - - for (int i = 0; i < col1.size(); i++) { - anyMatch = false; - - for (int j = 0; j < col2.size(); j++) { - if (j > i && symmetric) - break; - - if (matcher.apply(col1.get(i), col2.get(j))) { - if (possiblePermutations.get(i) == null) - possiblePermutations.set(i, new ArrayList<>()); - - possiblePermutations.get(i).add(j); - - if (symmetric) { - if (possiblePermutations.get(j) == null) - possiblePermutations.set(j, new ArrayList<>()); - possiblePermutations.get(j).add(i); - } - - anyMatch = true; - } - } - - if (!anyMatch) // Then there cannot be a matching permutation - return false; - } - - // Start recursive matching - return cartesianProduct(possiblePermutations, new Integer[possiblePermutations.size()], arrangement -> { - for (int i = 0; i < col2.size(); i++) - stack[i] = col2.get(arrangement[i]); - return permutationEmitter.apply(stack); - }); - } - public static boolean cartesianProduct(List> list, T[] stack, Function emitter) { if (list.size() == 0) return false; @@ -1053,174 +702,6 @@ private static boolean _cartesianProduct(int index, List> sets, T[] return matchFound; } - // TODO: This is broken --> remove - public static void topologicalSort(RewriterStatement stmt, final RuleContext ctx, BiFunction arrangable) { - MutableInt nameCtr = new MutableInt(); - stmt.forEachPostOrderWithDuplicates((el, parent, pIdx) -> { - if (el.getOperands().isEmpty()) { - el.unsafePutMeta("_tempName", nameCtr.intValue()); - nameCtr.increment(); - } else if (parent != null && arrangable.apply(el, parent)) { - el.unsafePutMeta("_tempName", nameCtr.intValue()); - nameCtr.increment(); - } - }); - - //Map> votes = new HashMap<>(); - //Map> gotRatedBy = new HashMap<>(); - //List> uncertainStatements = new ArrayList<>(); - - // First pass (try to figure out everything) - traversePostOrderWithDepthInfo(stmt, null, (el, depth, parent) -> { - if (el.getOperands() == null) - return; - - RewriterStatement voter = el; - createHierarchy(ctx, el, el.getOperands()); - - /*if (votes.containsKey(voter)) - return; - - if (arrangable.apply(el, parent)) { - List> uStatements = createHierarchy(ctx, el, el.getOperands()); - if (uStatements.size() > 0) { - uStatements.forEach(e -> System.out.println("Uncertain: " + e.stream().map(t -> t.stmt).collect(Collectors.toList()))); - uncertainStatements.addAll(uStatements); - } - } else { - Map ratings = new HashMap<>(); - votes.put(voter, ratings); - - for (int i = 0; i < el.getOperands().size(); i++) { - RewriterRule.IdentityRewriterStatement toRate = new RewriterRule.IdentityRewriterStatement(el.getOperands().get(i)); - - if (votes.containsKey(toRate)) - continue; - - ratings.put(toRate, i); - - Set ratedBy = gotRatedBy.get(toRate); - - if (ratedBy == null) { - ratedBy = new HashSet<>(); - gotRatedBy.put(toRate, ratedBy); - } - - ratedBy.add(voter); - } - }*/ - }, 0); - - // TODO: Erase temp names - - /*while (!uncertainStatements.isEmpty()) { - // Now, try to resolve the conflicts deterministically using element-wise comparison - Map, Integer> orderSet = new HashMap<>(); - - for (Set requiredComparisons : uncertainStatements) { - forEachDistinctBinaryCombination(new ArrayList<>(requiredComparisons), (s1, s2) -> { - Optional myOpt = compareStatements(s1, s2, votes, gotRatedBy); - if (myOpt.isPresent()) { - orderSet.put(new Tuple2<>(s1, s2), myOpt); - orderSet.put(new Tuple2<>(s2, s1), Optional.of(!myOpt.get())); - } else { - orderSet.put(new Tuple2<>(s1, s2), Optional.empty()); - } - }); - } - }*/ - - // Trigger a recomputation of the hash codes - stmt.prepareForHashing(); - stmt.recomputeHashCodes(ctx); - } - - public static void forEachDistinctBinaryCombination(List l, BiConsumer consumer) { - for (int i = 0; i < l.size(); i++) - for (int j = l.size() - 1; j > i; j--) - consumer.accept(l.get(i), l.get(j)); - } - - private static void traversePostOrderWithDepthInfo(RewriterStatement stmt, RewriterStatement parent, TriConsumer consumer, int currentDepth) { - if (stmt.getOperands() != null) - stmt.getOperands().forEach(el -> traversePostOrderWithDepthInfo(el, stmt, consumer, currentDepth + 1)); - - consumer.accept(stmt, currentDepth, parent); - } - - // Returns the range of uncertain elements [start, end) - public static void createHierarchy(final RuleContext ctx, RewriterStatement voter, List level) { - if (level.isEmpty()) - return; - - //level.sort(Comparator.comparing(el -> toOrderString(ctx, el))); - level.sort((el1, el2) -> compare(el1, el2, ctx)); - - /*List> ranges = new ArrayList<>(); - int currentRangeStart = 0; - - RewriterRule.IdentityRewriterStatement voterIds = new RewriterRule.IdentityRewriterStatement(voter); - Map votes = new HashMap<>(); - - { - RewriterRule.IdentityRewriterStatement firstIds = new RewriterRule.IdentityRewriterStatement(level.get(0)); - - Set voters = gotRatedBy.get(firstIds); - - if (voters == null) { - voters = new HashSet<>(); - gotRatedBy.put(firstIds, voters); - } - - voters.add(voterIds); - - allVotes.put(firstIds, votes); - votes.put(firstIds, 0); - } - - for (int i = 1; i < level.size(); i++) { - System.out.println(toOrderString(ctx, level.get(i-1)) + " <=> " + toOrderString(ctx, level.get(i))); - if (compare(level.get(i-1), level.get(i), ctx) == 0) { - if (i - currentRangeStart > 1) { - Set mSet = level.subList(currentRangeStart, i).stream().map(RewriterRule.IdentityRewriterStatement::new).collect(Collectors.toSet()); - - if (mSet.size() > 1) - ranges.add(mSet); - - System.out.println("E-Set: " + mSet.stream().map(id -> id.stmt.toParsableString(ctx, false)).collect(Collectors.toList())); - - currentRangeStart = i; - } - } - - RewriterRule.IdentityRewriterStatement ids = new RewriterRule.IdentityRewriterStatement(level.get(i)); - votes.put(ids, currentRangeStart); - - Set voters = gotRatedBy.get(ids); - - if (voters == null) { - voters = new HashSet<>(); - gotRatedBy.put(ids, voters); - } - - voters.add(voterIds); - } - - if (level.size() - currentRangeStart > 1) { - Set mSet = level - .subList(currentRangeStart, level.size()) - .stream().map(RewriterRule.IdentityRewriterStatement::new) - .collect(Collectors.toSet()); - - if (mSet.size() > 1) - ranges.add(mSet); - - System.out.println("E-Set: " + mSet.stream().map(id -> id.stmt.toParsableString(ctx, false)).collect(Collectors.toList())); - } - - return ranges;*/ - } - public static boolean isImplicitlyConvertible(String typeFrom, String typeTo) { if (typeFrom.equals(typeTo)) return true; @@ -1250,163 +731,6 @@ public static Object literalAs(String type, RewriterDataType literal) { } } - public static int compare(RewriterStatement stmt1, RewriterStatement stmt2, /*RewriterStatement p1, RewriterStatement p2, Map, Integer> globalOrders, BiFunction arrangable,*/ final RuleContext ctx) { - /*boolean arrangable1 = arrangable.apply(stmt1, p1); - boolean arrangable2 = arrangable.apply(stmt2, p2); - - if (arrangable1) { - if (!arrangable2) - return 1; - } else { - if (arrangable2) - return -1; - } - - RewriterRule.IdentityRewriterStatement id1 = new RewriterRule.IdentityRewriterStatement(stmt1); - RewriterRule.IdentityRewriterStatement id2 = new RewriterRule.IdentityRewriterStatement(stmt2); - - if (!globalOrders.isEmpty()) { - Integer result = globalOrders.get(new Tuple2<>(id1, id2)); - - if (result == null) - result = globalOrders.get(new Tuple2<>(id2, id1)); - - if (result != null) - return result; - }*/ - - int comp = toOrderString(ctx, stmt1).compareTo(toOrderString(ctx, stmt2)); - - if (comp != 0 || stmt1.getOperands().isEmpty()) - return comp; - - for (int i = 0; i < stmt1.getOperands().size() && comp == 0; i++) - comp = compare(stmt1.getOperands().get(i), stmt2.getOperands().get(i), ctx); - - if (comp == 0) { - Integer mName1 = (Integer)stmt1.getMeta("_tempName"); - - if (mName1 == null) - return 0; - - return mName1.toString().compareTo(stmt2.getMeta("_tempName").toString()); - } - - return comp; - } - - public static String toOrderString(final RuleContext ctx, RewriterStatement stmt) { - return toOrderString(ctx, stmt, false); - } - - public static String toOrderString(final RuleContext ctx, RewriterStatement stmt, boolean extendIfPossible) { - if (stmt.isInstruction()) { - Integer mName = (Integer)stmt.getMeta("_tempName"); - return stmt.getResultingDataType(ctx) + ":" + stmt.trueTypedInstruction(ctx) + "[" + stmt.refCtr + "](" + stmt.getOperands().size() + ")" + (mName == null ? "" : mName) + ";"; - } else { - return stmt.getResultingDataType(ctx) + ":" + (stmt.isLiteral() ? "L:" + stmt.getLiteral() : "V") + "[" + stmt.refCtr + "](0)" + stmt.getMeta("_tempName") + ";"; - } - } - - public static List mergeSubtreeCombinations(RewriterStatement stmt, List indices, List> mList, final RuleContext ctx, int maximumCombinations) { - if (indices.isEmpty()) - return List.of(stmt); - - List mergedTreeCombinations = new ArrayList<>(); - cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> { - RewriterStatement cpy = stmt.copyNode(); - for (int i = 0; i < stack.length; i++) - cpy.getOperands().set(indices.get(i), stack[i]); - cpy.consolidate(ctx); - cpy.prepareForHashing(); - cpy.recomputeHashCodes(ctx); - mergedTreeCombinations.add(cpy); - return mergedTreeCombinations.size() < maximumCombinations; - }); - - return mergedTreeCombinations; - } - - public static List generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) { - List l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations); - - if (ctx.metaPropagator != null) - l.forEach(subtree -> ctx.metaPropagator.apply(subtree)); - - return l.stream().map(subtree -> { - if (ctx.metaPropagator != null) - subtree = ctx.metaPropagator.apply(subtree); - - subtree.prepareForHashing(); - subtree.recomputeHashCodes(ctx); - return subtree; - }).collect(Collectors.toList()); - } - - private static Random rd = new Random(); - - private static List generateSubtrees(RewriterStatement stmt, Map> visited, final RuleContext ctx, int maxCombinations) { - if (stmt == null) - return Collections.emptyList(); - - RewriterStatement is = stmt; - List alreadyVisited = visited.get(is); - - if (alreadyVisited != null) - return alreadyVisited; - - if (stmt.getOperands().size() == 0) - return List.of(stmt); - - // Scan if operand is not a DataType - List indices = new ArrayList<>(); - for (int i = 0; i < stmt.getOperands().size(); i++) { - if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral()) - indices.add(i); - } - - int n = indices.size(); - int totalSubsets = 1 << n; - - List mList = new ArrayList<>(); - - visited.put(is, mList); - - List> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList()); - List out = new ArrayList<>(); - - for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) { - List> mOptionCpy = new ArrayList<>(mOptions); - - for (int i = 0; i < n; i++) { - // Check if the i-th child is included in the current subset - if ((subsetMask & (1 << i)) == 0) { - String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx); - String namePrefix = "tmp"; - if (dt.equals("MATRIX")) - namePrefix = "M"; - else if (dt.equals("FLOAT")) - namePrefix = "f"; - else if (dt.equals("INT")) - namePrefix = "i"; - else if (dt.equals("BOOL")) - namePrefix = "b"; - RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt); - mT.consolidate(ctx); - mOptionCpy.set(i, List.of(mT)); - } - } - - out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations)); - if (out.size() > maxCombinations) { - System.out.println("Aborting early due to too many combinations"); - return out; - } - } - - return out; - } - public static RuleContext buildDefaultContext() { RuleContext ctx = RewriterContextSettings.getDefaultContext(new Random()); ctx.metaPropagator = new MetaPropagator(ctx); @@ -1430,26 +754,6 @@ public static Function unfuseOperators(fin return lastUnfuse; } - private static RuleContext lastSparsityCtx; - private static Function lastPrepareForSparsity; - - @Deprecated - public static Function prepareForSparsityEstimation(final RuleContext ctx) { - if (lastSparsityCtx == ctx) - return lastPrepareForSparsity; - - ArrayList mRules = new ArrayList<>(); - RewriterRuleCollection.substituteFusedOps(mRules, ctx); - RewriterRuleCollection.substituteEquivalentStatements(mRules, ctx); - RewriterRuleCollection.eliminateMultipleCasts(mRules, ctx); - RewriterRuleCollection.canonicalizeBooleanStatements(mRules, ctx); - RewriterRuleCollection.canonicalizeAlgebraicStatements(mRules, ctx); - RewriterHeuristic heur = new RewriterHeuristic(new RewriterRuleSet(ctx, mRules)); - lastSparsityCtx = ctx; - lastPrepareForSparsity = heur::apply; - return lastPrepareForSparsity; - } - public static Function buildCanonicalFormConverter(final RuleContext ctx, boolean debug) { ArrayList algebraicCanonicalizationRules = new ArrayList<>(); RewriterRuleCollection.substituteEquivalentStatements(algebraicCanonicalizationRules, ctx); @@ -1551,11 +855,8 @@ public static Function buildCanonicalFormC if (debug) System.out.println("PRE1: " + stmt.toParsableString(ctx, false)); - stmt.compress(); + stmt.compress(); // To remove unnecessary metadata such as assertions that are not encoded in the graph TopologicalSort.sort(stmt, ctx); - // Somehow it is unstable if we only compress and sort once - //stmt.compress(); - //TopologicalSort.sort(stmt, ctx); if (debug) System.out.println("FINAL1: " + stmt.toParsableString(ctx, false)); @@ -1668,8 +969,6 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule idxExpr.getOperands().set(1, argList.get(0)); } - //toRemove.add(sum); - RewriterStatement outerSum = RewriterStatement.multiArgInstr(ctx, "+", toRemove.toArray(RewriterStatement[]::new)); List mul = new ArrayList<>(); @@ -1681,7 +980,6 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule mul.add(outerSum); RewriterStatement mulStmt = RewriterStatement.multiArgInstr(ctx, "*", mul.toArray(RewriterStatement[]::new)); - //mul.add(sum); return RewriterStatement.multiArgInstr(ctx, "+", mulStmt, sum); } @@ -1812,12 +1110,7 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final return argList.get(0); else if (argList.isEmpty()) return neutral; - } /*else if (argList.size() == 2 && ConstantFoldingFunctions.isNegNeutral(argList.get(literals[0]).getLiteral(), stmt.trueInstruction())) { - RewriterStatement neutral = argList.get(literals[0]); - argList.remove(literals[0]); - - return new RewriterInstruction("-", ctx, argList.get(0)); - }*/ + } } if (literals.length < 2) @@ -2006,10 +1299,6 @@ private static void postCleanupIndexExpr(RewriterStatement cur) { } } - public static RewriterStatement doCSE(RewriterStatement stmt, final RuleContext ctx) { - throw new NotImplementedException(); - } - public static void renameIllegalVarnames(final RuleContext ctx, RewriterStatement... stmts) { MutableInt matrixVarCtr = new MutableInt(0); MutableInt scalarVarCtr = new MutableInt(0); diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java index 4428ea5cdfa..9aafc4999ae 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java @@ -125,7 +125,7 @@ public static void testExpressionClustering() { // return; // Skip // First, build all possible subtrees //System.out.println("Eval:\n" + expr.toParsableString(ctx, true)); - List subExprs = RewriterUtils.generateSubtrees(expr, ctx, pruneDataSubexrBiggerThan); + List subExprs = RewriterSearchUtils.generateSubtrees(expr, ctx, pruneDataSubexrBiggerThan); if (subExprs.size() > pruneDataSubexrBiggerThan) System.out.println("Critical number of subtrees: " + subExprs.size()); if (subExprs.size() > 2 * pruneDataSubexrBiggerThan) { diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java index b15522b9507..9d3426bfe28 100644 --- a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java @@ -20,6 +20,7 @@ package org.apache.sysds.test.component.codegen.rewrite.functions; import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; import org.apache.sysds.hops.rewriter.utils.RewriterUtils; import org.apache.sysds.hops.rewriter.RuleContext; import org.junit.BeforeClass; @@ -40,7 +41,7 @@ public static void setup() { @Test public void test1() { RewriterStatement stmt = RewriterUtils.parse("+(1, a)", ctx, "LITERAL_INT:1", "FLOAT:a"); - List subtrees = RewriterUtils.generateSubtrees(stmt, ctx, 100); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); for (RewriterStatement sub : subtrees) { System.out.println("=========="); @@ -53,7 +54,7 @@ public void test1() { @Test public void test2() { RewriterStatement stmt = RewriterUtils.parse("+(+(1, b), a)", ctx, "LITERAL_INT:1", "FLOAT:a,b"); - List subtrees = RewriterUtils.generateSubtrees(stmt, ctx, 100); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); for (RewriterStatement sub : subtrees) { System.out.println("=========="); @@ -66,7 +67,7 @@ public void test2() { @Test public void test3() { RewriterStatement stmt = RewriterUtils.parse("-(+(1.0,A),B)", ctx, "LITERAL_FLOAT:1.0", "MATRIX:A,B"); - List subtrees = RewriterUtils.generateSubtrees(stmt, ctx, 100); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); for (RewriterStatement sub : subtrees) { System.out.println("==========");