From b40afa7a692aa4af95632dd0dcc6e6672eb139d6 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 02:47:35 -0500 Subject: [PATCH 01/20] CompileTime: fix reshape elementwise --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 28 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index d48ad98ea..fe6085a07 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7507,12 +7507,28 @@ struct ReshapeElementwise final : OpRewritePattern { SmallVector ops; for (auto v : elem->getOperands()) { - ops.push_back(rewriter.create( - op.getLoc(), - RankedTensorType::get( - op.getType().getShape(), - cast(v.getType()).getElementType()), - v)); + auto NT = RankedTensorType::get( + op.getType().getShape(), + cast(v.getType()).getElementType()); + Value reshaped = nullptr; + for (auto u : v.getUsers()) { + auto re = dyn_cast(u); + if (!re) + continue; + if (re.getType() != NT) + continue; + reshaped = re; + break; + } + if (!reshaped) { + reshaped = rewriter.create( + op.getLoc(), + RankedTensorType::get( + op.getType().getShape(), + cast(v.getType()).getElementType()), + v); + } + ops.push_back(reshaped); } auto newOp = rewriter.create( elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops), From 4715ee0d8d37ce498b26d99d760f7d9c050563a8 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 02:59:50 -0500 Subject: [PATCH 02/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index fe6085a07..55c481723 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7510,7 +7510,7 @@ struct ReshapeElementwise final : OpRewritePattern { auto NT = RankedTensorType::get( op.getType().getShape(), cast(v.getType()).getElementType()); - Value reshaped = nullptr; + stablehlo::ReshapeOp reshaped = nullptr; for (auto u : v.getUsers()) { auto re = dyn_cast(u); if (!re) @@ -7521,12 +7521,16 @@ struct ReshapeElementwise final : OpRewritePattern { break; } if (!reshaped) { - reshaped = rewriter.create( - op.getLoc(), - RankedTensorType::get( - op.getType().getShape(), - cast(v.getType()).getElementType()), - v); + reshaped = rewriter.create(op.getLoc(), NT, v); + } else { + if (reshaped->getBlock() == op->getBlock()) { + if (op->isBeforeInBlock(reshaped)) { + rewriter.modifyOpInPlace(reshaped, + [&]() { reshaped->moveBefore(op); }); + } + } else { + reshaped = rewriter.create(op.getLoc(), NT, v); + } } ops.push_back(reshaped); } From 6144bd4bd882552705ae09743fc1657848928e38 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 03:15:30 -0500 Subject: [PATCH 03/20] Also broadcast2reshape --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 33 ++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 55c481723..beaa00079 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3787,9 +3787,36 @@ struct BroadcastToReshape final // replace with reshape if (op.getType() == op.getOperand().getType()) rewriter.replaceOp(op, op.getOperand()); - else + else { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getOperand()); + auto NT = op.getType(); + stablehlo::ReshapeOp reshaped = nullptr; + for (auto u : op.getOperand().getUsers()) { + auto re = dyn_cast(u); + if (!re) + continue; + if (re.getType() != NT) + continue; + reshaped = re; + break; + } + if (!reshaped) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); + } else { + if (reshaped->getBlock() == op->getBlock()) { + if (op->isBeforeInBlock(reshaped)) { + rewriter.modifyOpInPlace(reshaped, + [&]() { reshaped->moveBefore(op); }); + } + rewriter.replaceOp(op, reshaped); + } else { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getOperand()); + } + } + } return success(); } }; @@ -7495,6 +7522,10 @@ struct ReshapeElementwise final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, PatternRewriter &rewriter) const override { + if (op.getType() == op.getOperand.getType()) { + rewriter.replaceOp(op, op.getOperand()); + return success(); + } auto elem = op.getOperand().getDefiningOp(); if (!elem) return failure(); From 83e37fdb0d1da9c42e0e88e317e96e65a841f5ea Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 11:30:23 -0500 Subject: [PATCH 04/20] fmt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 46 +++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index beaa00079..4d306cc10 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3790,33 +3790,33 @@ struct BroadcastToReshape final else { rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); - auto NT = op.getType(); - stablehlo::ReshapeOp reshaped = nullptr; - for (auto u : op.getOperand().getUsers()) { - auto re = dyn_cast(u); - if (!re) - continue; - if (re.getType() != NT) - continue; - reshaped = re; - break; - } - if (!reshaped) { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getOperand()); - } else { - if (reshaped->getBlock() == op->getBlock()) { - if (op->isBeforeInBlock(reshaped)) { - rewriter.modifyOpInPlace(reshaped, - [&]() { reshaped->moveBefore(op); }); - } - rewriter.replaceOp(op, reshaped); - } else { + auto NT = op.getType(); + stablehlo::ReshapeOp reshaped = nullptr; + for (auto u : op.getOperand().getUsers()) { + auto re = dyn_cast(u); + if (!re) + continue; + if (re.getType() != NT) + continue; + reshaped = re; + break; + } + if (!reshaped) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); + } else { + if (reshaped->getBlock() == op->getBlock()) { + if (op->isBeforeInBlock(reshaped)) { + rewriter.modifyOpInPlace(reshaped, + [&]() { reshaped->moveBefore(op); }); + } + rewriter.replaceOp(op, reshaped); + } else { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getOperand()); + } } } - } return success(); } }; From c14add7089509f7f70e99b0113f76f6625eb42d2 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 11:31:32 -0500 Subject: [PATCH 05/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 4d306cc10..25dafd970 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7522,7 +7522,7 @@ struct ReshapeElementwise final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, PatternRewriter &rewriter) const override { - if (op.getType() == op.getOperand.getType()) { + if (op.getType() == op.getOperand().getType()) { rewriter.replaceOp(op, op.getOperand()); return success(); } From 29a187c2b9dfd613f67a7ae66ead316e96b2e06d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 11:36:43 -0500 Subject: [PATCH 06/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 25dafd970..aeddd6231 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3788,8 +3788,6 @@ struct BroadcastToReshape final if (op.getType() == op.getOperand().getType()) rewriter.replaceOp(op, op.getOperand()); else { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getOperand()); auto NT = op.getType(); stablehlo::ReshapeOp reshaped = nullptr; for (auto u : op.getOperand().getUsers()) { From 33a7afdda65f029fd0a42a315dfef006a51348d6 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:12:46 -0500 Subject: [PATCH 07/20] fast dominator --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 140 +++++++++++++++--- .../raising/affine_to_stablehlo13.mlir | 6 +- .../raising/affine_to_stablehlo15.mlir | 6 +- .../raising/affine_to_stablehlo_pforred.mlir | 12 +- .../raising/affine_to_stablehlo_pforred2.mlir | 13 +- 5 files changed, 141 insertions(+), 36 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index aeddd6231..cefc1cd31 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3746,6 +3746,66 @@ struct BroadcastReshape final } }; +// Returns legal, and if reshaped comes before op +std::pair fastDoesADominateB(Operation* reshaped, Operation* op, Value v) { + assert(reshaped); + assert(op); + size_t limit = 1000; + if (reshaped->getBlock() == op->getBlock()) { + if (op->getBlock()->isOpOrderValid() || op->getBlock()->getOperations().size() <= limit) { + return std::make_pair(true, reshaped->isBeforeInBlock(op)); + } + } + if (v) + if (auto pred = v.getDefiningOp()) { + bool seenReshape = false; + bool seenUser = false; + Operation* cur = pred->getNextNode(); + for (int i=0; cur && iisAncestor(reshaped)) { + seenReshape = true; + } + if (cur->isAncestor(op)) { + seenUser = true; + } + if (seenReshape || seenUser) break; + } + if (seenReshape && !seenUser) { + return std::make_pair(true, true); + } + if (!seenReshape && seenUser) { + return std::make_pair(true, false); + } + } + { + bool seenUser = false; + Operation* cur = reshaped->getNextNode(); + for (int i=0; cur && iisAncestor(op)) { + seenUser = true; + return std::make_pair(true, true); + } + } + if (!cur) { + std::make_pair(true, false); + } + } + { + bool seenReshape = false; + Operation* cur = op->getNextNode(); + for (int i=0; cur && iisAncestor(reshaped)) { + seenReshape = true; + return std::make_pair(true, false); + } + } + if (!cur) { + 
std::make_pair(true, true); + } + } + return std::make_pair(false, false); +} + struct BroadcastToReshape final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3790,29 +3850,36 @@ struct BroadcastToReshape final else { auto NT = op.getType(); stablehlo::ReshapeOp reshaped = nullptr; + bool before = false; for (auto u : op.getOperand().getUsers()) { auto re = dyn_cast(u); if (!re) continue; if (re.getType() != NT) continue; + auto &&[legal, before2] = fastDoesADominateB(op, re, op.getOperand()); + if (!legal) continue; + before = before2; reshaped = re; break; } if (!reshaped) { + //llvm::errs() << " replaced to reshape: " << op << "\n"; + if (auto rop = op.getOperand().getDefiningOp()) { + rewriter.setInsertionPointAfter(rop); + } else if (auto ba = dyn_cast(op.getOperand())) { + rewriter.setInsertionPointToStart(ba.getOwner()); + } rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); } else { - if (reshaped->getBlock() == op->getBlock()) { - if (op->isBeforeInBlock(reshaped)) { + if (before) { + // llvm::errs() << " moved reshape: " << reshaped << "\n"; rewriter.modifyOpInPlace(reshaped, [&]() { reshaped->moveBefore(op); }); } + //llvm::errs() << " replaced op with reshape: " << op << " " << reshaped << "\n"; rewriter.replaceOp(op, reshaped); - } else { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getOperand()); - } } } return success(); @@ -7528,12 +7595,20 @@ struct ReshapeElementwise final : OpRewritePattern { if (!elem) return failure(); - if (onlySingleUser && !llvm::hasSingleElement(elem->getUsers())) + bool singleUse = llvm::hasSingleElement(elem->getUsers()); + if (onlySingleUser && !singleUse) return failure(); if (!elem->hasTrait()) return failure(); + if (singleUse) { + auto pt = rewriter.getInsertionPoint (); + pt--; + rewriter.setInsertionPoint(rewriter.getInsertionBlock(), pt); + } + + //llvm::errs() << " reshaping " << *elem << " reshape: " << op << "\n"; SmallVector ops; for (auto v : elem->getOperands()) { 
auto NT = RankedTensorType::get( @@ -7550,23 +7625,38 @@ struct ReshapeElementwise final : OpRewritePattern { break; } if (!reshaped) { + // llvm::errs() << " creating new reshape of arg " << v << "\n"; reshaped = rewriter.create(op.getLoc(), NT, v); } else { - if (reshaped->getBlock() == op->getBlock()) { - if (op->isBeforeInBlock(reshaped)) { + auto &&[legal, before] = fastDoesADominateB(op, reshaped, v); + if (legal) { + if (before) { + //llvm::errs() << " moved reshape " << reshaped << " of arg " << v << "\n"; rewriter.modifyOpInPlace(reshaped, [&]() { reshaped->moveBefore(op); }); } } else { + // llvm::errs() << " non block reshape reshape " << reshaped << " of arg " << v << "\n"; reshaped = rewriter.create(op.getLoc(), NT, v); } } ops.push_back(reshaped); } + + if (singleUse) { + //llvm::errs() << " modifying in place\n"; + rewriter.modifyOpInPlace(elem, [&]() { + elem->setOperands(ops); + elem->getResult(0).setType(op.getType()); + }); + rewriter.replaceOp(op, elem); + } else { auto newOp = rewriter.create( elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops), TypeRange(op.getType()), elem->getAttrs(), {}, {}); + //llvm::errs() << " created reshaped elem: " << newOp << "\n"; rewriter.replaceOp(op, newOp); + } return success(); } }; @@ -7769,13 +7859,16 @@ template struct CSE final : OpRewritePattern { continue; if (nop->getBlock() != op->getBlock()) continue; - if (nop->isBeforeInBlock(op)) { - rewriter.replaceOp(op, nop); - return success(); - } else { - rewriter.replaceOp(nop, op); - return success(); - } + auto &&[legal, before] = fastDoesADominateB(nop, op, nullptr); + if (legal) { + if (before) { + rewriter.replaceOp(op, nop); + return success(); + } else { + rewriter.replaceOp(nop, op); + return success(); + } + } } return failure(); } @@ -12470,7 +12563,9 @@ struct CommonCompareExpressionRewrite continue; if (userCompareOp.getLhs() == lhs && userCompareOp.getRhs() == rhs) { - if (user->isBeforeInBlock(op)) { + auto &&[legal, before] = 
fastDoesADominateB(user, op, opOperand); + if (legal) { + if (before) { auto negatedCondition = rewriter.create( op.getLoc(), userCompareOp.getResult()); rewriter.replaceOp(op, negatedCondition); @@ -12481,6 +12576,7 @@ struct CommonCompareExpressionRewrite rewriter.replaceOp(user, negatedCondition); return success(); } + } } } } @@ -14377,6 +14473,16 @@ struct EnzymeHLOOptPass GreedyRewriteConfig config; config.maxIterations = max_iterations; config.useTopDownTraversal = top_down; + getOperation()->walk([](Operation* op) { + for (auto ®ion : op->getRegions()) { + for (auto &blk : region.getBlocks()) { + + if (!blk.isOpOrderValid()) { + blk.recomputeOpOrder(); + } + } + } + }); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { signalPassFailure(); diff --git a/test/lit_tests/raising/affine_to_stablehlo13.mlir b/test/lit_tests/raising/affine_to_stablehlo13.mlir index 5d9aa9bf1..d442dd316 100644 --- a/test/lit_tests/raising/affine_to_stablehlo13.mlir +++ b/test/lit_tests/raising/affine_to_stablehlo13.mlir @@ -94,9 +94,9 @@ module { } } // CHECK: func.func private @repeat_iv_raised(%arg0: tensor<10xi64>, %arg1: tensor<10xi64>, %arg2: tensor<10x10xf64>, %arg3: tensor<10xf64>) -> (tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>) { -// CHECK-NEXT: %0 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64> -// CHECK-NEXT: %1 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64> -// CHECK-NEXT: %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64> +// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64> +// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64> +// CHECK-NEXT: %2 = stablehlo.concatenate %1, %0, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64> // CHECK-NEXT: %3 = "stablehlo.gather"(%arg2, %2) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, 
slice_sizes = array}> : (tensor<10x10xf64>, tensor<10x2xi64>) -> tensor<10xf64> // CHECK-NEXT: return %arg0, %arg1, %arg2, %3 : tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64> // CHECK-NEXT: } diff --git a/test/lit_tests/raising/affine_to_stablehlo15.mlir b/test/lit_tests/raising/affine_to_stablehlo15.mlir index f116fe23f..f0db5de42 100644 --- a/test/lit_tests/raising/affine_to_stablehlo15.mlir +++ b/test/lit_tests/raising/affine_to_stablehlo15.mlir @@ -25,9 +25,9 @@ module { // CHECK-NEXT: %1 = stablehlo.dynamic_slice %arg0, %iterArg, %c_1, sizes = [1, 10] : (tensor<4x10xf32>, tensor, tensor) -> tensor<1x10xf32> // CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<1x10xf32>) -> tensor<10xf32> // CHECK-NEXT: %3 = arith.mulf %2, %2 : tensor<10xf32> -// CHECK-NEXT: %4 = stablehlo.multiply %iterArg, %c_0 : tensor -// CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32> -// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %5, %4, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor, tensor) -> tensor<16x10xf32> +// CHECK-NEXT: %4 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32> +// CHECK-NEXT: %5 = stablehlo.multiply %iterArg, %c_0 : tensor +// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %4, %5, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor, tensor) -> tensor<16x10xf32> // CHECK-NEXT: %7 = stablehlo.add %iterArg, %c : tensor // CHECK-NEXT: stablehlo.return %7, %6 : tensor, tensor<16x10xf32> // CHECK-NEXT: } diff --git a/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir b/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir index 6dcad5808..b3756bca8 100644 --- a/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir +++ b/test/lit_tests/raising/affine_to_stablehlo_pforred.mlir @@ -56,10 +56,10 @@ module @"reactant_loop!" 
attributes {mhlo.num_partitions = 1 : i64, mhlo.num_rep // CHECK-NEXT: %15 = stablehlo.reduce(%12 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor) -> tensor<20x45xf64> // CHECK-NEXT: %16 = stablehlo.reduce(%14 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor) -> tensor<20x45xf64> // CHECK-NEXT: %17 = arith.addf %5, %15 {fastmathFlags = #llvm.fastmath} : tensor<20x45xf64> -// CHECK-NEXT: %18 = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath} : tensor<20x45xf64> -// CHECK-NEXT: %19 = stablehlo.reshape %18 : (tensor<20x45xf64>) -> tensor<1x20x45xf64> -// CHECK-NEXT: %20 = stablehlo.dynamic_update_slice %arg1, %19, %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor, tensor, tensor) -> tensor<1x35x59xf64> -// CHECK-NEXT: %21 = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64> -// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %21, %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor, tensor, tensor) -> tensor<1x34x59xf64> -// CHECK-NEXT: return %22, %20, %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, tensor<24x34x59xf64>, tensor<24x35x59xf64> +// CHECK-NEXT: %[[i21:.+]] = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64> +// CHECK-NEXT: %[[i18:.+]] = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath} : tensor<20x45xf64> +// CHECK-NEXT: %[[i19:.+]] = stablehlo.reshape %[[i18]] : (tensor<20x45xf64>) -> tensor<1x20x45xf64> +// CHECK-NEXT: %[[i20:.+]] = stablehlo.dynamic_update_slice %arg1, %[[i19]], %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor, tensor, tensor) -> tensor<1x35x59xf64> +// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %[[i21]], %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor, tensor, tensor) -> tensor<1x34x59xf64> +// CHECK-NEXT: return %22, %[[i20]], %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, 
tensor<24x34x59xf64>, tensor<24x35x59xf64> // CHECK-NEXT: } diff --git a/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir b/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir index bdcbfc0eb..51fb3333d 100644 --- a/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir +++ b/test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir @@ -119,11 +119,10 @@ func.func private @"##call__Z29gpu__compute_barotropic_mode_16CompilerMetadataI1 // CHECK-NEXT: %50 = stablehlo.reduce(%44 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor) -> tensor<96x192xf64> // CHECK-NEXT: %51 = stablehlo.reduce(%49 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor) -> tensor<96x192xf64> // CHECK-NEXT: %52 = arith.addf %33, %50 {fastmathFlags = #llvm.fastmath} : tensor<96x192xf64> -// CHECK-NEXT: %53 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath} : tensor<96x192xf64> -// CHECK-NEXT: %54 = stablehlo.reshape %53 : (tensor<96x192xf64>) -> tensor<1x96x192xf64> -// CHECK-NEXT: %55 = stablehlo.dynamic_update_slice %arg1, %54, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor, tensor, tensor) -> tensor<1x140x206xf64> -// CHECK-NEXT: %56 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64> -// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %56, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor, tensor, tensor) -> tensor<1x140x206xf64> -// CHECK-NEXT: return %57, %55, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64> +// CHECK-NEXT: %53 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64> +// CHECK-NEXT: %54 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath} : tensor<96x192xf64> +// CHECK-NEXT: %55 = stablehlo.reshape %54 : (tensor<96x192xf64>) -> tensor<1x96x192xf64> +// 
CHECK-NEXT: %56 = stablehlo.dynamic_update_slice %arg1, %55, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor, tensor, tensor) -> tensor<1x140x206xf64> +// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %53, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor, tensor, tensor) -> tensor<1x140x206xf64> +// CHECK-NEXT: return %57, %56, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64> // CHECK-NEXT: } - From e5591d7627e717ae55e8f5fba8e5b6fb589cacef Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 13:12:53 -0500 Subject: [PATCH 08/20] fmt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 233 +++++++++++----------- 1 file changed, 120 insertions(+), 113 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index cefc1cd31..2e212cca6 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3747,63 +3747,66 @@ struct BroadcastReshape final }; // Returns legal, and if reshaped comes before op -std::pair fastDoesADominateB(Operation* reshaped, Operation* op, Value v) { - assert(reshaped); - assert(op); - size_t limit = 1000; - if (reshaped->getBlock() == op->getBlock()) { - if (op->getBlock()->isOpOrderValid() || op->getBlock()->getOperations().size() <= limit) { - return std::make_pair(true, reshaped->isBeforeInBlock(op)); - } - } - if (v) - if (auto pred = v.getDefiningOp()) { - bool seenReshape = false; - bool seenUser = false; - Operation* cur = pred->getNextNode(); - for (int i=0; cur && iisAncestor(reshaped)) { - seenReshape = true; - } - if (cur->isAncestor(op)) { - seenUser = true; - } - if (seenReshape || seenUser) break; - } - if (seenReshape && !seenUser) { - return std::make_pair(true, true); - } - if (!seenReshape && seenUser) { - return 
std::make_pair(true, false); - } - } - { - bool seenUser = false; - Operation* cur = reshaped->getNextNode(); - for (int i=0; cur && iisAncestor(op)) { - seenUser = true; - return std::make_pair(true, true); - } - } - if (!cur) { - std::make_pair(true, false); - } - } - { - bool seenReshape = false; - Operation* cur = op->getNextNode(); - for (int i=0; cur && iisAncestor(reshaped)) { - seenReshape = true; - return std::make_pair(true, false); - } - } - if (!cur) { - std::make_pair(true, true); - } - } - return std::make_pair(false, false); +std::pair fastDoesADominateB(Operation *reshaped, Operation *op, + Value v) { + assert(reshaped); + assert(op); + size_t limit = 1000; + if (reshaped->getBlock() == op->getBlock()) { + if (op->getBlock()->isOpOrderValid() || + op->getBlock()->getOperations().size() <= limit) { + return std::make_pair(true, reshaped->isBeforeInBlock(op)); + } + } + if (v) + if (auto pred = v.getDefiningOp()) { + bool seenReshape = false; + bool seenUser = false; + Operation *cur = pred->getNextNode(); + for (int i = 0; cur && i < limit; i++) { + if (cur->isAncestor(reshaped)) { + seenReshape = true; + } + if (cur->isAncestor(op)) { + seenUser = true; + } + if (seenReshape || seenUser) + break; + } + if (seenReshape && !seenUser) { + return std::make_pair(true, true); + } + if (!seenReshape && seenUser) { + return std::make_pair(true, false); + } + } + { + bool seenUser = false; + Operation *cur = reshaped->getNextNode(); + for (int i = 0; cur && i < limit; i++) { + if (cur->isAncestor(op)) { + seenUser = true; + return std::make_pair(true, true); + } + } + if (!cur) { + std::make_pair(true, false); + } + } + { + bool seenReshape = false; + Operation *cur = op->getNextNode(); + for (int i = 0; cur && i < limit; i++) { + if (cur->isAncestor(reshaped)) { + seenReshape = true; + return std::make_pair(true, false); + } + } + if (!cur) { + std::make_pair(true, true); + } + } + return std::make_pair(false, false); } struct BroadcastToReshape final @@ 
-3857,29 +3860,31 @@ struct BroadcastToReshape final continue; if (re.getType() != NT) continue; - auto &&[legal, before2] = fastDoesADominateB(op, re, op.getOperand()); - if (!legal) continue; - before = before2; + auto &&[legal, before2] = fastDoesADominateB(op, re, op.getOperand()); + if (!legal) + continue; + before = before2; reshaped = re; break; } if (!reshaped) { - //llvm::errs() << " replaced to reshape: " << op << "\n"; - if (auto rop = op.getOperand().getDefiningOp()) { - rewriter.setInsertionPointAfter(rop); - } else if (auto ba = dyn_cast(op.getOperand())) { - rewriter.setInsertionPointToStart(ba.getOwner()); - } + // llvm::errs() << " replaced to reshape: " << op << "\n"; + if (auto rop = op.getOperand().getDefiningOp()) { + rewriter.setInsertionPointAfter(rop); + } else if (auto ba = dyn_cast(op.getOperand())) { + rewriter.setInsertionPointToStart(ba.getOwner()); + } rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); } else { - if (before) { - // llvm::errs() << " moved reshape: " << reshaped << "\n"; - rewriter.modifyOpInPlace(reshaped, - [&]() { reshaped->moveBefore(op); }); - } - //llvm::errs() << " replaced op with reshape: " << op << " " << reshaped << "\n"; - rewriter.replaceOp(op, reshaped); + if (before) { + // llvm::errs() << " moved reshape: " << reshaped << "\n"; + rewriter.modifyOpInPlace(reshaped, + [&]() { reshaped->moveBefore(op); }); + } + // llvm::errs() << " replaced op with reshape: " << op << " " << + // reshaped << "\n"; + rewriter.replaceOp(op, reshaped); } } return success(); @@ -7603,12 +7608,12 @@ struct ReshapeElementwise final : OpRewritePattern { return failure(); if (singleUse) { - auto pt = rewriter.getInsertionPoint (); + auto pt = rewriter.getInsertionPoint(); pt--; rewriter.setInsertionPoint(rewriter.getInsertionBlock(), pt); } - //llvm::errs() << " reshaping " << *elem << " reshape: " << op << "\n"; + // llvm::errs() << " reshaping " << *elem << " reshape: " << op << "\n"; SmallVector ops; for (auto v : 
elem->getOperands()) { auto NT = RankedTensorType::get( @@ -7625,18 +7630,20 @@ struct ReshapeElementwise final : OpRewritePattern { break; } if (!reshaped) { - // llvm::errs() << " creating new reshape of arg " << v << "\n"; + // llvm::errs() << " creating new reshape of arg " << v << "\n"; reshaped = rewriter.create(op.getLoc(), NT, v); } else { - auto &&[legal, before] = fastDoesADominateB(op, reshaped, v); + auto &&[legal, before] = fastDoesADominateB(op, reshaped, v); if (legal) { if (before) { - //llvm::errs() << " moved reshape " << reshaped << " of arg " << v << "\n"; + // llvm::errs() << " moved reshape " << reshaped << " of arg " << v + // << "\n"; rewriter.modifyOpInPlace(reshaped, [&]() { reshaped->moveBefore(op); }); } } else { - // llvm::errs() << " non block reshape reshape " << reshaped << " of arg " << v << "\n"; + // llvm::errs() << " non block reshape reshape " << reshaped << " of + // arg " << v << "\n"; reshaped = rewriter.create(op.getLoc(), NT, v); } } @@ -7644,18 +7651,18 @@ struct ReshapeElementwise final : OpRewritePattern { } if (singleUse) { - //llvm::errs() << " modifying in place\n"; + // llvm::errs() << " modifying in place\n"; rewriter.modifyOpInPlace(elem, [&]() { - elem->setOperands(ops); - elem->getResult(0).setType(op.getType()); + elem->setOperands(ops); + elem->getResult(0).setType(op.getType()); }); rewriter.replaceOp(op, elem); } else { - auto newOp = rewriter.create( - elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops), - TypeRange(op.getType()), elem->getAttrs(), {}, {}); - //llvm::errs() << " created reshaped elem: " << newOp << "\n"; - rewriter.replaceOp(op, newOp); + auto newOp = rewriter.create( + elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops), + TypeRange(op.getType()), elem->getAttrs(), {}, {}); + // llvm::errs() << " created reshaped elem: " << newOp << "\n"; + rewriter.replaceOp(op, newOp); } return success(); } @@ -7859,16 +7866,16 @@ template struct CSE final : OpRewritePattern { 
continue; if (nop->getBlock() != op->getBlock()) continue; - auto &&[legal, before] = fastDoesADominateB(nop, op, nullptr); - if (legal) { - if (before) { + auto &&[legal, before] = fastDoesADominateB(nop, op, nullptr); + if (legal) { + if (before) { rewriter.replaceOp(op, nop); return success(); - } else { + } else { rewriter.replaceOp(nop, op); return success(); - } - } + } + } } return failure(); } @@ -12563,20 +12570,20 @@ struct CommonCompareExpressionRewrite continue; if (userCompareOp.getLhs() == lhs && userCompareOp.getRhs() == rhs) { - auto &&[legal, before] = fastDoesADominateB(user, op, opOperand); - if (legal) { - if (before) { - auto negatedCondition = rewriter.create( - op.getLoc(), userCompareOp.getResult()); - rewriter.replaceOp(op, negatedCondition); - return success(); - } else { - auto negatedCondition = rewriter.create( - userCompareOp.getLoc(), op.getResult()); - rewriter.replaceOp(user, negatedCondition); - return success(); + auto &&[legal, before] = fastDoesADominateB(user, op, opOperand); + if (legal) { + if (before) { + auto negatedCondition = rewriter.create( + op.getLoc(), userCompareOp.getResult()); + rewriter.replaceOp(op, negatedCondition); + return success(); + } else { + auto negatedCondition = rewriter.create( + userCompareOp.getLoc(), op.getResult()); + rewriter.replaceOp(user, negatedCondition); + return success(); + } } - } } } } @@ -14473,16 +14480,16 @@ struct EnzymeHLOOptPass GreedyRewriteConfig config; config.maxIterations = max_iterations; config.useTopDownTraversal = top_down; - getOperation()->walk([](Operation* op) { + getOperation()->walk([](Operation *op) { for (auto ®ion : op->getRegions()) { for (auto &blk : region.getBlocks()) { - - if (!blk.isOpOrderValid()) { - blk.recomputeOpOrder(); - } - } + + if (!blk.isOpOrderValid()) { + blk.recomputeOpOrder(); + } + } } - }); + }); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { signalPassFailure(); From 
ba2fac872b5884df15b759d02203284ab5d6a105 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 13:16:21 -0500 Subject: [PATCH 09/20] locality rewrites --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 2e212cca6..cc2aac532 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7644,6 +7644,11 @@ struct ReshapeElementwise final : OpRewritePattern { } else { // llvm::errs() << " non block reshape reshape " << reshaped << " of // arg " << v << "\n"; + if (auto rop = v.getDefiningOp()) { + rewriter.setInsertionPointAfter(rop); + } else if (auto ba = dyn_cast(v)) { + rewriter.setInsertionPointToStart(ba.getOwner()); + } reshaped = rewriter.create(op.getLoc(), NT, v); } } @@ -7658,6 +7663,7 @@ struct ReshapeElementwise final : OpRewritePattern { }); rewriter.replaceOp(op, elem); } else { + rewriter.setInsertionPointAfter(op); auto newOp = rewriter.create( elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops), TypeRange(op.getType()), elem->getAttrs(), {}, {}); From 6735fea56baa2d8123347121705ed2e57bc48245 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:16:27 -0500 Subject: [PATCH 10/20] fmt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index cc2aac532..2b3c23232 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7644,11 +7644,11 @@ struct ReshapeElementwise final : OpRewritePattern { } else { // llvm::errs() << " non block reshape reshape " << reshaped << " of // arg " << v << "\n"; - if (auto rop = v.getDefiningOp()) { - rewriter.setInsertionPointAfter(rop); - } else if (auto ba = dyn_cast(v)) { - rewriter.setInsertionPointToStart(ba.getOwner()); - } + if (auto rop = v.getDefiningOp()) { + rewriter.setInsertionPointAfter(rop); + } else if (auto ba = dyn_cast(v)) { + rewriter.setInsertionPointToStart(ba.getOwner()); + } reshaped = rewriter.create(op.getLoc(), NT, v); } } From 8a5ac408675ed8f956ce0eae9084780141e91ad3 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:20:29 -0500 Subject: [PATCH 11/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 2b3c23232..2cb69e54a 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3772,6 +3772,7 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, } if (seenReshape || seenUser) break; + cur = cur->getNextNode(); } if (seenReshape && !seenUser) { return std::make_pair(true, true); @@ -3788,6 +3789,7 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, seenUser = true; return std::make_pair(true, true); } + cur = cur->getNextNode(); } if (!cur) { std::make_pair(true, false); @@ -3801,6 +3803,7 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, seenReshape = true; return std::make_pair(true, false); } + cur = cur->getNextNode(); } if (!cur) { std::make_pair(true, true); From 3ab9cd9242726a0c6fbd8e7bfac4b81ffbbaa891 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 13:20:36 -0500 Subject: [PATCH 12/20] fmt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 2cb69e54a..b5ec90840 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3772,7 +3772,7 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, } if (seenReshape || seenUser) break; - cur = cur->getNextNode(); + cur = cur->getNextNode(); } if (seenReshape && !seenUser) { return std::make_pair(true, true); From b86707c4b58612c7c72d0a8a102d872da5e333a4 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:26:16 -0500 Subject: [PATCH 13/20] smaller limit --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index b5ec90840..629f4c28a 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3751,7 +3751,7 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, Value v) { assert(reshaped); assert(op); - size_t limit = 1000; + size_t limit = 200; if (reshaped->getBlock() == op->getBlock()) { if (op->getBlock()->isOpOrderValid() || op->getBlock()->getOperations().size() <= limit) { From 15db95bb8ed07db6b993d986da4bfc422543f09f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 13:33:35 -0500 Subject: [PATCH 14/20] more asymptotic --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 629f4c28a..67fa6d225 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3753,8 +3753,10 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, assert(op); size_t limit = 200; if (reshaped->getBlock() == op->getBlock()) { - if (op->getBlock()->isOpOrderValid() || - op->getBlock()->getOperations().size() <= limit) { + + // TODO we could do the following, if size wasn't O(N) =/ + // op->getBlock()->getOperations().size() <= limit) { + if (op->getBlock()->isOpOrderValid()) { return std::make_pair(true, reshaped->isBeforeInBlock(op)); } } From 45dd839b7faf545085b93e0f24ffb3e03c8e5391 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:33:43 -0500 Subject: [PATCH 15/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 67fa6d225..966f9aeea 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3753,7 +3753,7 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, assert(op); size_t limit = 200; if (reshaped->getBlock() == op->getBlock()) { - + // TODO we could do the following, if size wasn't O(N) =/ // op->getBlock()->getOperations().size() <= limit) { if (op->getBlock()->isOpOrderValid()) { From a117f557dbbefb9c381c999790ffbcc6dd2687ef Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 13:38:57 -0500 Subject: [PATCH 16/20] just block --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 80 ++++++++++++----------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 966f9aeea..077f5a2bb 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3759,56 +3759,60 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, if (op->getBlock()->isOpOrderValid()) { return std::make_pair(true, reshaped->isBeforeInBlock(op)); } - } - if (v) - if (auto pred = v.getDefiningOp()) { - bool seenReshape = false; - bool seenUser = false; - Operation *cur = pred->getNextNode(); - for (int i = 0; cur && i < limit; i++) { - if (cur->isAncestor(reshaped)) { - seenReshape = true; + if (v) + if (auto pred = v.getDefiningOp()) { + bool seenReshape = false; + bool seenUser = false; + Operation *cur = pred->getNextNode(); + for (int i = 0; cur && i < limit; i++) { + // TODO we could make this an isancestor query, but of course compile + // time if (cur->isAncestor(reshaped)) + if (cur == reshaped) + 
seenReshape = true; } - if (cur->isAncestor(op)) { + // if (cur->isAncestor(op)) + if (cur == op) { seenUser = true; } if (seenReshape || seenUser) break; cur = cur->getNextNode(); } - if (seenReshape && !seenUser) { - return std::make_pair(true, true); + if (seenReshape && !seenUser) { + return std::make_pair(true, true); + } + if (!seenReshape && seenUser) { + return std::make_pair(true, false); + } + { + bool seenUser = false; + Operation *cur = reshaped->getNextNode(); + for (int i = 0; cur && i < limit; i++) { + // if (cur->isAncestor(op)) + if (cur == op) { + seenUser = true; + return std::make_pair(true, true); + } + cur = cur->getNextNode(); } - if (!seenReshape && seenUser) { - return std::make_pair(true, false); + if (!cur) { + std::make_pair(true, false); } } - { - bool seenUser = false; - Operation *cur = reshaped->getNextNode(); - for (int i = 0; cur && i < limit; i++) { - if (cur->isAncestor(op)) { - seenUser = true; - return std::make_pair(true, true); + { + bool seenReshape = false; + Operation *cur = op->getNextNode(); + for (int i = 0; cur && i < limit; i++) { + // if (cur->isAncestor(reshaped)) + if (cur == reshaped) { + seenReshape = true; + return std::make_pair(true, false); + } + cur = cur->getNextNode(); } - cur = cur->getNextNode(); - } - if (!cur) { - std::make_pair(true, false); - } - } - { - bool seenReshape = false; - Operation *cur = op->getNextNode(); - for (int i = 0; cur && i < limit; i++) { - if (cur->isAncestor(reshaped)) { - seenReshape = true; - return std::make_pair(true, false); + if (!cur) { + std::make_pair(true, true); } - cur = cur->getNextNode(); - } - if (!cur) { - std::make_pair(true, true); } } return std::make_pair(false, false); From d497844c620f0ac4f769b7ae61294a5d6ea6ac7f Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:40:15 -0500 Subject: [PATCH 17/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 077f5a2bb..151fb296d 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3769,14 +3769,14 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, // time if (cur->isAncestor(reshaped)) if (cur == reshaped) seenReshape = true; + // if (cur->isAncestor(op)) + if (cur == op) { + seenUser = true; + } + if (seenReshape || seenUser) + break; + cur = cur->getNextNode(); } - // if (cur->isAncestor(op)) - if (cur == op) { - seenUser = true; - } - if (seenReshape || seenUser) - break; - cur = cur->getNextNode(); } if (seenReshape && !seenUser) { return std::make_pair(true, true); From d620b3df245829c29078541f731aea4ab34f5461 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 13:41:32 -0500 Subject: [PATCH 18/20] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 151fb296d..a923c5e63 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3777,13 +3777,13 @@ std::pair fastDoesADominateB(Operation *reshaped, Operation *op, break; cur = cur->getNextNode(); } + if (seenReshape && !seenUser) { + return std::make_pair(true, true); + } + if (!seenReshape && seenUser) { + return std::make_pair(true, false); + } } - if (seenReshape && !seenUser) { - return std::make_pair(true, true); - } - if (!seenReshape && seenUser) { - return std::make_pair(true, false); - } { bool seenUser = false; Operation *cur = reshaped->getNextNode(); From 1106dd021c8aa1ab15ce2c292999d606b76d1dc4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Apr 2025 13:46:13 -0500 Subject: [PATCH 19/20] multi single user --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index a923c5e63..24b869133 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7609,7 +7609,14 @@ struct ReshapeElementwise final : OpRewritePattern { if (!elem) return failure(); - bool singleUse = llvm::hasSingleElement(elem->getUsers()); + bool singleUse = true; + for (auto U : elem->getUsers()) { + if (U != op) { + singleUse = false; + break; + } + } + if (onlySingleUser && !singleUse) return failure(); From 0f7e89a992e4160ac74c6d3e6c65db2e9cf8c4d6 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Mon, 7 Apr 2025 14:18:49 -0500 Subject: [PATCH 20/20] best status atm --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 65 ++++++++++++----------- test/lit_tests/reshapeelementwise.mlir | 6 +-- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 24b869133..5913d9ad2 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3877,7 +3877,6 @@ struct BroadcastToReshape final break; } if (!reshaped) { - // llvm::errs() << " replaced to reshape: " << op << "\n"; if (auto rop = op.getOperand().getDefiningOp()) { rewriter.setInsertionPointAfter(rop); } else if (auto ba = dyn_cast(op.getOperand())) { @@ -3887,12 +3886,9 @@ struct BroadcastToReshape final op.getOperand()); } else { if (before) { - // llvm::errs() << " moved reshape: " << reshaped << "\n"; rewriter.modifyOpInPlace(reshaped, [&]() { reshaped->moveBefore(op); }); } - // llvm::errs() << " replaced op with reshape: " << op << " " << - // reshaped << "\n"; rewriter.replaceOp(op, reshaped); } } @@ -7609,82 +7605,87 @@ struct ReshapeElementwise final : OpRewritePattern { if (!elem) return failure(); + if (!elem->hasTrait()) + return failure(); + bool singleUse = true; + SmallVector toReplace; for (auto U : elem->getUsers()) { - if (U != op) { - singleUse = false; - break; + if (auto re = dyn_cast(U)) { + if (re.getType() == op.getType()) { + toReplace.push_back(re); + continue; + } } + singleUse = false; + break; } if (onlySingleUser && !singleUse) return failure(); - if (!elem->hasTrait()) - return failure(); - if (singleUse) { auto pt = rewriter.getInsertionPoint(); pt--; rewriter.setInsertionPoint(rewriter.getInsertionBlock(), pt); } - // llvm::errs() << " reshaping " << *elem << " reshape: " << op << "\n"; SmallVector ops; for (auto v : elem->getOperands()) { auto NT = RankedTensorType::get( op.getType().getShape(), 
cast(v.getType()).getElementType()); stablehlo::ReshapeOp reshaped = nullptr; + bool before; for (auto u : v.getUsers()) { auto re = dyn_cast(u); if (!re) continue; if (re.getType() != NT) continue; + auto &&[legal, before2] = fastDoesADominateB(elem, re, v); + if (!legal) { + continue; + } + before = before2; reshaped = re; break; } if (!reshaped) { - // llvm::errs() << " creating new reshape of arg " << v << "\n"; + if (auto rop = v.getDefiningOp()) { + rewriter.setInsertionPointAfter(rop); + } else if (auto ba = dyn_cast(v)) { + rewriter.setInsertionPointToStart(ba.getOwner()); + } reshaped = rewriter.create(op.getLoc(), NT, v); } else { - auto &&[legal, before] = fastDoesADominateB(op, reshaped, v); - if (legal) { - if (before) { - // llvm::errs() << " moved reshape " << reshaped << " of arg " << v - // << "\n"; - rewriter.modifyOpInPlace(reshaped, - [&]() { reshaped->moveBefore(op); }); - } - } else { - // llvm::errs() << " non block reshape reshape " << reshaped << " of - // arg " << v << "\n"; + if (before) { if (auto rop = v.getDefiningOp()) { - rewriter.setInsertionPointAfter(rop); - } else if (auto ba = dyn_cast(v)) { - rewriter.setInsertionPointToStart(ba.getOwner()); + rewriter.modifyOpInPlace(reshaped, + [&]() { reshaped->moveAfter(rop); }); + } else { + rewriter.modifyOpInPlace(reshaped, + [&]() { reshaped->moveBefore(elem); }); } - reshaped = rewriter.create(op.getLoc(), NT, v); } } ops.push_back(reshaped); } if (singleUse) { - // llvm::errs() << " modifying in place\n"; rewriter.modifyOpInPlace(elem, [&]() { elem->setOperands(ops); elem->getResult(0).setType(op.getType()); }); - rewriter.replaceOp(op, elem); + for (auto re : toReplace) + rewriter.replaceOp(re, elem); } else { - rewriter.setInsertionPointAfter(op); + rewriter.setInsertionPointAfter(elem); auto newOp = rewriter.create( elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops), TypeRange(op.getType()), elem->getAttrs(), {}, {}); - // llvm::errs() << " created reshaped elem: " << 
newOp << "\n"; - rewriter.replaceOp(op, newOp); + for (auto re : toReplace) + rewriter.replaceOp(re, newOp); } return success(); } diff --git a/test/lit_tests/reshapeelementwise.mlir b/test/lit_tests/reshapeelementwise.mlir index 46fffe346..851dd59a2 100644 --- a/test/lit_tests/reshapeelementwise.mlir +++ b/test/lit_tests/reshapeelementwise.mlir @@ -14,9 +14,9 @@ module { } // CHECK: func.func @main(%arg0: tensor<100x200x300xbf16>, %arg1: tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> { -// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> -// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> -// CHECK-NEXT: %2 = stablehlo.subtract %0, %1 : tensor<20000x300xbf16> +// CHECK-DAG: %[[a0:.+]] = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> +// CHECK-DAG: %[[a1:.+]] = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> +// CHECK-NEXT: %2 = stablehlo.subtract %[[a0]], %[[a1]] : tensor<20000x300xbf16> // CHECK-NEXT: return %2 : tensor<20000x300xbf16> // CHECK-NEXT: } // CHECK: func.func @main2(%arg0: tensor<100x200x300xbf16>) -> tensor<20000x300xf32> {