
CompileTime: fix reshape elementwise #680

Open: wants to merge 20 commits into main
247 changes: 216 additions & 31 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -3746,6 +3746,78 @@ struct BroadcastReshape final
}
};

// Returns {legal, before}: `legal` is true if the relative order of `reshaped`
// and `op` could be determined cheaply, and `before` is true if `reshaped`
// comes before `op`.
std::pair<bool, bool> fastDoesADominateB(Operation *reshaped, Operation *op,
Value v) {
assert(reshaped);
assert(op);
size_t limit = 200;
if (reshaped->getBlock() == op->getBlock()) {

// TODO we could do the following, if size wasn't O(N) =/
// op->getBlock()->getOperations().size() <= limit) {
if (op->getBlock()->isOpOrderValid()) {
return std::make_pair(true, reshaped->isBeforeInBlock(op));
}
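// The block's op order is not known to be valid, and recomputing it (or
// calling isBeforeInBlock, which may trigger that recomputation) is
// O(block size). Instead, do up to three bounded forward scans of at most
// `limit` ops: from the defining op of `v`, from `reshaped`, and from `op`.
// If none of them gives a definitive answer, report the query as not legal.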
if (v)
if (auto pred = v.getDefiningOp()) {
bool seenReshape = false;
bool seenUser = false;
Operation *cur = pred->getNextNode();
for (int i = 0; cur && i < limit; i++) {
// TODO: this could be an isAncestor query, but that costs compile time:
// if (cur->isAncestor(reshaped))
if (cur == reshaped)
seenReshape = true;
// if (cur->isAncestor(op))
if (cur == op) {
seenUser = true;
}
if (seenReshape || seenUser)
break;
cur = cur->getNextNode();
}
if (seenReshape && !seenUser) {
return std::make_pair(true, true);
}
if (!seenReshape && seenUser) {
return std::make_pair(true, false);
}
}
{
bool seenUser = false;
Operation *cur = reshaped->getNextNode();
for (int i = 0; cur && i < limit; i++) {
// if (cur->isAncestor(op))
if (cur == op) {
seenUser = true;
return std::make_pair(true, true);
}
cur = cur->getNextNode();
}
if (!cur) {
// Reached the end of the block without seeing `op`, so `op` precedes
// `reshaped`.
return std::make_pair(true, false);
}
}
{
bool seenReshape = false;
Operation *cur = op->getNextNode();
for (int i = 0; cur && i < limit; i++) {
// if (cur->isAncestor(reshaped))
if (cur == reshaped) {
seenReshape = true;
return std::make_pair(true, false);
}
cur = cur->getNextNode();
}
if (!cur) {
// Reached the end of the block without seeing `reshaped`, so `reshaped`
// precedes `op`.
return std::make_pair(true, true);
}
}
}
return std::make_pair(false, false);
}
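// Usage sketch (illustrative only; the patterns below do the move via
// rewriter.modifyOpInPlace): a pattern that wants to reuse an existing op
// `existing` in place of `op` only does so when the ordering query is
// definitive:
//   auto &&[legal, before] = fastDoesADominateB(op, existing, operand);
//   if (legal) {
//     if (before)                 // `op` precedes `existing`
//       existing->moveBefore(op); // hoist so the reused op dominates `op`
//     rewriter.replaceOp(op, existing);
//   }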

struct BroadcastToReshape final
: OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
@@ -3787,9 +3859,39 @@ struct BroadcastToReshape final
// replace with reshape
if (op.getType() == op.getOperand().getType())
rewriter.replaceOp(op, op.getOperand());
else {
auto NT = op.getType();
stablehlo::ReshapeOp reshaped = nullptr;
bool before = false;
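// Look for an existing reshape of the operand to the target type; if its
// position relative to this broadcast can be determined cheaply, reuse it
// instead of creating a duplicate reshape.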
for (auto u : op.getOperand().getUsers()) {
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
if (!re)
continue;
if (re.getType() != NT)
continue;
auto &&[legal, before2] = fastDoesADominateB(op, re, op.getOperand());
if (!legal)
continue;
before = before2;
reshaped = re;
break;
}
if (!reshaped) {
if (auto rop = op.getOperand().getDefiningOp()) {
rewriter.setInsertionPointAfter(rop);
} else if (auto ba = dyn_cast<BlockArgument>(op.getOperand())) {
rewriter.setInsertionPointToStart(ba.getOwner());
}
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(),
op.getOperand());
} else {
if (before) {
rewriter.modifyOpInPlace(reshaped,
[&]() { reshaped->moveBefore(op); });
}
rewriter.replaceOp(op, reshaped);
}
}
return success();
}
};
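// Illustrative example (hypothetical shapes, not taken from this PR's tests):
// given
//   %r = stablehlo.reshape %arg0 : (tensor<4x5xf32>) -> tensor<1x4x5xf32>
//   %b = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<4x5xf32>) -> tensor<1x4x5xf32>
// the broadcast only introduces unit dimensions, so instead of materializing a
// second identical reshape for %b, the pattern now reuses %r (hoisting it
// before %b when the ordering query says that is needed) and replaces %b with
// it.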
@@ -7495,29 +7597,96 @@ struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {

LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
PatternRewriter &rewriter) const override {
if (op.getType() == op.getOperand().getType()) {
rewriter.replaceOp(op, op.getOperand());
return success();
}
auto elem = op.getOperand().getDefiningOp();
if (!elem)
return failure();

if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
return failure();

bool singleUse = true;
SmallVector<stablehlo::ReshapeOp> toReplace;
for (auto U : elem->getUsers()) {
if (auto re = dyn_cast<stablehlo::ReshapeOp>(U)) {
if (re.getType() == op.getType()) {
toReplace.push_back(re);
continue;
}
}
singleUse = false;
break;
}

if (onlySingleUser && !singleUse)
return failure();

if (singleUse) {
auto pt = rewriter.getInsertionPoint();
pt--;
rewriter.setInsertionPoint(rewriter.getInsertionBlock(), pt);
}

SmallVector<Value> ops;
for (auto v : elem->getOperands()) {
auto NT = RankedTensorType::get(
op.getType().getShape(),
cast<RankedTensorType>(v.getType()).getElementType());
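// Prefer reusing an existing reshape of this operand to the new type; only
// create a fresh reshape when none is found or its ordering cannot be
// established cheaply.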
stablehlo::ReshapeOp reshaped = nullptr;
bool before = false;
for (auto u : v.getUsers()) {
auto re = dyn_cast<stablehlo::ReshapeOp>(u);
if (!re)
continue;
if (re.getType() != NT)
continue;
auto &&[legal, before2] = fastDoesADominateB(elem, re, v);
if (!legal) {
continue;
}
before = before2;
reshaped = re;
break;
}
if (!reshaped) {
if (auto rop = v.getDefiningOp()) {
rewriter.setInsertionPointAfter(rop);
} else if (auto ba = dyn_cast<BlockArgument>(v)) {
rewriter.setInsertionPointToStart(ba.getOwner());
}
reshaped = rewriter.create<stablehlo::ReshapeOp>(op.getLoc(), NT, v);
} else {
if (before) {
if (auto rop = v.getDefiningOp()) {
rewriter.modifyOpInPlace(reshaped,
[&]() { reshaped->moveAfter(rop); });
} else {
rewriter.modifyOpInPlace(reshaped,
[&]() { reshaped->moveBefore(elem); });
}
}
}
ops.push_back(reshaped);
}

if (singleUse) {
rewriter.modifyOpInPlace(elem, [&]() {
elem->setOperands(ops);
elem->getResult(0).setType(op.getType());
});
for (auto re : toReplace)
rewriter.replaceOp(re, elem);
} else {
rewriter.setInsertionPointAfter(elem);
auto newOp = rewriter.create(
elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
TypeRange(op.getType()), elem->getAttrs(), {}, {});
for (auto re : toReplace)
rewriter.replaceOp(re, newOp);
}
return success();
}
};
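// Illustrative example (mirrors test/lit_tests/reshapeelementwise.mlir):
// reshape-of-elementwise becomes elementwise-of-reshapes, reusing reshapes
// that already exist instead of adding duplicates:
//   %2 = stablehlo.subtract %arg0, %arg1 : tensor<100x200x300xbf16>
//   %3 = stablehlo.reshape %2 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// becomes
//   %0 = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
//   %1 = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
//   %2 = stablehlo.subtract %0, %1 : tensor<20000x300xbf16>
// and when every user of the elementwise op is a reshape to the same type, the
// elementwise op itself is retyped in place rather than recreated.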
@@ -7720,12 +7889,15 @@ template <typename T> struct CSE final : OpRewritePattern<T> {
continue;
if (nop->getBlock() != op->getBlock())
continue;
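// Use the bounded ordering query instead of calling isBeforeInBlock
// directly, since the latter may recompute the whole block's op order.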
auto &&[legal, before] = fastDoesADominateB(nop, op, nullptr);
if (legal) {
if (before) {
rewriter.replaceOp(op, nop);
return success();
} else {
rewriter.replaceOp(nop, op);
return success();
}
}
}
return failure();
@@ -12421,16 +12593,19 @@ struct CommonCompareExpressionRewrite
continue;

if (userCompareOp.getLhs() == lhs && userCompareOp.getRhs() == rhs) {
auto &&[legal, before] = fastDoesADominateB(user, op, opOperand);
if (legal) {
if (before) {
auto negatedCondition = rewriter.create<stablehlo::NotOp>(
op.getLoc(), userCompareOp.getResult());
rewriter.replaceOp(op, negatedCondition);
return success();
} else {
auto negatedCondition = rewriter.create<stablehlo::NotOp>(
userCompareOp.getLoc(), op.getResult());
rewriter.replaceOp(user, negatedCondition);
return success();
}
}
}
}
@@ -14328,6 +14503,16 @@ struct EnzymeHLOOptPass
GreedyRewriteConfig config;
config.maxIterations = max_iterations;
config.useTopDownTraversal = top_down;
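// Recompute each block's op ordering up front so that the isBeforeInBlock
// fast path in fastDoesADominateB starts out cheap during pattern
// application.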
getOperation()->walk([](Operation *op) {
for (auto &region : op->getRegions()) {
for (auto &blk : region.getBlocks()) {

if (!blk.isOpOrderValid()) {
blk.recomputeOpOrder();
}
}
}
});
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
signalPassFailure();
6 changes: 3 additions & 3 deletions test/lit_tests/raising/affine_to_stablehlo13.mlir
@@ -94,9 +94,9 @@ module {
}
}
// CHECK: func.func private @repeat_iv_raised(%arg0: tensor<10xi64>, %arg1: tensor<10xi64>, %arg2: tensor<10x10xf64>, %arg3: tensor<10xf64>) -> (tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>) {
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<10xi64>) -> tensor<10x1xi64>
// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<10xi64>) -> tensor<10x1xi64>
// CHECK-NEXT: %2 = stablehlo.concatenate %1, %0, dim = 1 : (tensor<10x1xi64>, tensor<10x1xi64>) -> tensor<10x2xi64>
// CHECK-NEXT: %3 = "stablehlo.gather"(%arg2, %2) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<10x10xf64>, tensor<10x2xi64>) -> tensor<10xf64>
// CHECK-NEXT: return %arg0, %arg1, %arg2, %3 : tensor<10xi64>, tensor<10xi64>, tensor<10x10xf64>, tensor<10xf64>
// CHECK-NEXT: }
6 changes: 3 additions & 3 deletions test/lit_tests/raising/affine_to_stablehlo15.mlir
@@ -25,9 +25,9 @@ module {
// CHECK-NEXT: %1 = stablehlo.dynamic_slice %arg0, %iterArg, %c_1, sizes = [1, 10] : (tensor<4x10xf32>, tensor<i64>, tensor<i64>) -> tensor<1x10xf32>
// CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<1x10xf32>) -> tensor<10xf32>
// CHECK-NEXT: %3 = arith.mulf %2, %2 : tensor<10xf32>
// CHECK-NEXT: %4 = stablehlo.reshape %3 : (tensor<10xf32>) -> tensor<1x10xf32>
// CHECK-NEXT: %5 = stablehlo.multiply %iterArg, %c_0 : tensor<i64>
// CHECK-NEXT: %6 = stablehlo.dynamic_update_slice %iterArg_2, %4, %5, %c_1 : (tensor<16x10xf32>, tensor<1x10xf32>, tensor<i64>, tensor<i64>) -> tensor<16x10xf32>
// CHECK-NEXT: %7 = stablehlo.add %iterArg, %c : tensor<i64>
// CHECK-NEXT: stablehlo.return %7, %6 : tensor<i64>, tensor<16x10xf32>
// CHECK-NEXT: }
12 changes: 6 additions & 6 deletions test/lit_tests/raising/affine_to_stablehlo_pforred.mlir
@@ -56,10 +56,10 @@ module @"reactant_loop!" attributes {mhlo.num_partitions = 1 : i64, mhlo.num_rep
// CHECK-NEXT: %15 = stablehlo.reduce(%12 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor<f64>) -> tensor<20x45xf64>
// CHECK-NEXT: %16 = stablehlo.reduce(%14 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<9x20x45xf64>, tensor<f64>) -> tensor<20x45xf64>
// CHECK-NEXT: %17 = arith.addf %5, %15 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
// CHECK-NEXT: %[[i21:.+]] = stablehlo.reshape %17 : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
// CHECK-NEXT: %[[i18:.+]] = arith.addf %8, %16 {fastmathFlags = #llvm.fastmath<none>} : tensor<20x45xf64>
// CHECK-NEXT: %[[i19:.+]] = stablehlo.reshape %[[i18]] : (tensor<20x45xf64>) -> tensor<1x20x45xf64>
// CHECK-NEXT: %[[i20:.+]] = stablehlo.dynamic_update_slice %arg1, %[[i19]], %c_0, %c, %c : (tensor<1x35x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x35x59xf64>
// CHECK-NEXT: %22 = stablehlo.dynamic_update_slice %arg0, %[[i21]], %c_0, %c, %c : (tensor<1x34x59xf64>, tensor<1x20x45xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x34x59xf64>
// CHECK-NEXT: return %22, %[[i20]], %arg2, %arg3, %arg4 : tensor<1x34x59xf64>, tensor<1x35x59xf64>, tensor<24xf64>, tensor<24x34x59xf64>, tensor<24x35x59xf64>
// CHECK-NEXT: }
13 changes: 6 additions & 7 deletions test/lit_tests/raising/affine_to_stablehlo_pforred2.mlir
@@ -119,11 +119,10 @@ func.func private @"##call__Z29gpu__compute_barotropic_mode_16CompilerMetadataI1
// CHECK-NEXT: %50 = stablehlo.reduce(%44 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor<f64>) -> tensor<96x192xf64>
// CHECK-NEXT: %51 = stablehlo.reduce(%49 init: %cst_4) applies stablehlo.add across dimensions = [2] : (tensor<96x192x19xf64>, tensor<f64>) -> tensor<96x192xf64>
// CHECK-NEXT: %52 = arith.addf %33, %50 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
// CHECK-NEXT: %53 = stablehlo.reshape %52 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
// CHECK-NEXT: %54 = arith.addf %37, %51 {fastmathFlags = #llvm.fastmath<none>} : tensor<96x192xf64>
// CHECK-NEXT: %55 = stablehlo.reshape %54 : (tensor<96x192xf64>) -> tensor<1x96x192xf64>
// CHECK-NEXT: %56 = stablehlo.dynamic_update_slice %arg1, %55, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
// CHECK-NEXT: %57 = stablehlo.dynamic_update_slice %arg0, %53, %c_3, %c_2, %c : (tensor<1x140x206xf64>, tensor<1x96x192xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x140x206xf64>
// CHECK-NEXT: return %57, %56, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<1x140x206xf64>, tensor<1x140x206xf64>, tensor<35xf64>, tensor<34xf64>, tensor<1x110x206xf64>, tensor<34x110x206xf64>, tensor<34x110x206xf64>
// CHECK-NEXT: }

6 changes: 3 additions & 3 deletions test/lit_tests/reshapeelementwise.mlir
@@ -14,9 +14,9 @@ module {
}

// CHECK: func.func @main(%arg0: tensor<100x200x300xbf16>, %arg1: tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> {
// CHECK-DAG: %[[a0:.+]] = stablehlo.reshape %arg0 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// CHECK-DAG: %[[a1:.+]] = stablehlo.reshape %arg1 : (tensor<100x200x300xbf16>) -> tensor<20000x300xbf16>
// CHECK-NEXT: %2 = stablehlo.subtract %[[a0]], %[[a1]] : tensor<20000x300xbf16>
// CHECK-NEXT: return %2 : tensor<20000x300xbf16>
// CHECK-NEXT: }
// CHECK: func.func @main2(%arg0: tensor<100x200x300xbf16>) -> tensor<20000x300xf32> {
Expand Down