Skip to content

Commit f65729e

Browse files
committed
Added index cast optimizer for affine.if
1 parent 679a74f commit f65729e

File tree

2 files changed

+314
-0
lines changed

2 files changed

+314
-0
lines changed

src/enzyme_ad/jax/Passes/AffineCFG.cpp

+151
Original file line numberDiff line numberDiff line change
@@ -5211,6 +5211,156 @@ static bool isLoopParallel(AffineForOp forOp,
52115211
return ::isLoopMemoryParallel(forOp);
52125212
}
52135213

5214+
/// Moves index_cast operations inside affine.if regions when they are
5215+
/// applied to the if operation's results.
5216+
struct AffineIfIndexCastOptimizer
5217+
: public OpRewritePattern<mlir::affine::AffineIfOp> {
5218+
using OpRewritePattern<mlir::affine::AffineIfOp>::OpRewritePattern;
5219+
5220+
LogicalResult matchAndRewrite(mlir::affine::AffineIfOp ifOp,
5221+
PatternRewriter &rewriter) const override {
5222+
// Track which results are used only in index_cast operations
5223+
SmallVector<bool> resultsToOptimize(ifOp.getNumResults(), false);
5224+
SmallVector<Type> newResultTypes(ifOp.getNumResults());
5225+
SmallVector<arith::IndexCastOp> indexCastOps(ifOp.getNumResults(), nullptr);
5226+
bool hasOptimizableResult = false;
5227+
5228+
// Check each result to see if it's used in a single index_cast operation
5229+
for (auto result : llvm::enumerate(ifOp.getResults())) {
5230+
int resultIdx = result.index();
5231+
Value value = result.value();
5232+
5233+
// Skip if the result is not an integer type that can be index-cast
5234+
if (!value.getType().isIntOrIndex()) {
5235+
newResultTypes[resultIdx] = value.getType();
5236+
continue;
5237+
}
5238+
5239+
// Skip dead results
5240+
if (value.use_empty()) {
5241+
newResultTypes[resultIdx] = value.getType();
5242+
continue;
5243+
}
5244+
5245+
// Check if this result is only used by a single index_cast operation
5246+
arith::IndexCastOp indexCastUser = nullptr;
5247+
bool singleUser = true;
5248+
5249+
for (Operation *user : value.getUsers()) {
5250+
if (auto indexCast = dyn_cast<arith::IndexCastOp>(user)) {
5251+
if (!indexCastUser) {
5252+
indexCastUser = indexCast;
5253+
} else {
5254+
// More than one index_cast uses this result
5255+
singleUser = false;
5256+
break;
5257+
}
5258+
} else {
5259+
// Used by something other than index_cast
5260+
singleUser = false;
5261+
break;
5262+
}
5263+
}
5264+
5265+
if (singleUser && indexCastUser) {
5266+
// Result is only used by one index_cast
5267+
resultsToOptimize[resultIdx] = true;
5268+
newResultTypes[resultIdx] = indexCastUser.getType();
5269+
indexCastOps[resultIdx] = indexCastUser;
5270+
hasOptimizableResult = true;
5271+
} else {
5272+
newResultTypes[resultIdx] = value.getType();
5273+
}
5274+
}
5275+
5276+
// If no optimizable results, return failure
5277+
if (!hasOptimizableResult)
5278+
return failure();
5279+
5280+
// Create a new if operation with the optimized regions
5281+
auto loc = ifOp.getLoc();
5282+
auto newIfOp = rewriter.create<mlir::affine::AffineIfOp>(
5283+
loc, newResultTypes, ifOp.getIntegerSet(), ifOp.getOperands(),
5284+
/* withElseRegion */ true);
5285+
5286+
// Update the then region
5287+
rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
5288+
newIfOp.getThenRegion().begin());
5289+
5290+
// Update the else region
5291+
rewriter.inlineRegionBefore(ifOp.getElseRegion(), newIfOp.getElseRegion(),
5292+
newIfOp.getElseRegion().begin());
5293+
5294+
// Get the yield ops in both regions
5295+
auto &thenBlock = newIfOp.getThenRegion().front();
5296+
auto thenYield =
5297+
cast<mlir::affine::AffineYieldOp>(thenBlock.getTerminator());
5298+
5299+
auto &elseBlock = newIfOp.getElseRegion().front();
5300+
auto elseYield =
5301+
cast<mlir::affine::AffineYieldOp>(elseBlock.getTerminator());
5302+
5303+
// Transform the yield operations to apply index_cast where needed
5304+
SmallVector<Value> thenYieldOperands;
5305+
SmallVector<Value> elseYieldOperands;
5306+
5307+
for (auto it : llvm::enumerate(resultsToOptimize)) {
5308+
size_t idx = it.index();
5309+
bool shouldOptimize = it.value();
5310+
5311+
if (shouldOptimize) {
5312+
// Insert index_cast before the yield
5313+
rewriter.setInsertionPoint(thenYield);
5314+
Value thenCasted = rewriter.create<arith::IndexCastOp>(
5315+
loc, indexCastOps[idx].getType(), thenYield.getOperand(idx));
5316+
thenYieldOperands.push_back(thenCasted);
5317+
5318+
rewriter.setInsertionPoint(elseYield);
5319+
Value elseCasted = rewriter.create<arith::IndexCastOp>(
5320+
loc, indexCastOps[idx].getType(), elseYield.getOperand(idx));
5321+
elseYieldOperands.push_back(elseCasted);
5322+
} else {
5323+
// Keep the original operand
5324+
thenYieldOperands.push_back(thenYield.getOperand(idx));
5325+
elseYieldOperands.push_back(elseYield.getOperand(idx));
5326+
}
5327+
}
5328+
5329+
// Replace the original yield operations with new ones
5330+
rewriter.setInsertionPoint(thenYield);
5331+
rewriter.replaceOpWithNewOp<mlir::affine::AffineYieldOp>(thenYield,
5332+
thenYieldOperands);
5333+
5334+
rewriter.setInsertionPoint(elseYield);
5335+
rewriter.replaceOpWithNewOp<mlir::affine::AffineYieldOp>(elseYield,
5336+
elseYieldOperands);
5337+
5338+
// Replace uses of the old index_cast ops with the corresponding results
5339+
// from the new if
5340+
for (auto it : llvm::enumerate(resultsToOptimize)) {
5341+
size_t idx = it.index();
5342+
bool shouldOptimize = it.value();
5343+
5344+
if (shouldOptimize && indexCastOps[idx]) {
5345+
rewriter.replaceOp(indexCastOps[idx], newIfOp.getResult(idx));
5346+
}
5347+
}
5348+
5349+
// Replace the remaining results
5350+
for (auto it : llvm::enumerate(newIfOp.getResults())) {
5351+
size_t idx = it.index();
5352+
if (!resultsToOptimize[idx]) {
5353+
ifOp.getResult(idx).replaceAllUsesWith(it.value());
5354+
}
5355+
}
5356+
5357+
// Erase the old if op
5358+
rewriter.eraseOp(ifOp);
5359+
5360+
return success();
5361+
}
5362+
};
5363+
52145364
struct AffineParallelizePattern : public OpRewritePattern<affine::AffineForOp> {
52155365

52165366
AffineParallelizePattern(bool parallelReductions, MLIRContext *context)
@@ -5316,4 +5466,5 @@ void populateAffineParallelizationPattern(MLIRContext &context,
53165466
RewritePatternSet &patterns) {
53175467
patterns.insert<AffineParallelizePattern>(/*parallelReductions=*/true,
53185468
&context);
5469+
patterns.insert<AffineIfIndexCastOptimizer>(&context);
53195470
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
// RUN: enzymexlamlir-opt %s -affine-cfg | FileCheck %s

// Define a simple affine set shared by all tests below.
#set0 = affine_set<(d0) : (d0 >= 0)>
// Basic test case - single result with index_cast: the cast is sunk into
// both branches and the if yields index directly.
// CHECK-LABEL: func @test_affine_if_index_cast
func.func @test_affine_if_index_cast(%arg0: index) -> index {
  %c1_i64 = arith.constant 1 : i64
  %c2_i64 = arith.constant 2 : i64

  // CHECK: %[[RESULT:.*]] = affine.if
  // CHECK: %[[CAST1:.*]] = arith.index_cast %c1_i64 : i64 to index
  // CHECK: affine.yield %[[CAST1]]
  // CHECK: else
  // CHECK: %[[CAST2:.*]] = arith.index_cast %c2_i64 : i64 to index
  // CHECK: affine.yield %[[CAST2]]
  // CHECK-NOT: arith.index_cast

  %0 = affine.if #set0(%arg0) -> (i64) {
    affine.yield %c1_i64 : i64
  } else {
    affine.yield %c2_i64 : i64
  }

  %1 = arith.index_cast %0 : i64 to index
  return %1 : index
}
29+
30+
// -----
31+
32+
// Multiple results with only one used in index_cast: only result #1 is
// rewritten; result #0 keeps its i64 type.
// CHECK-LABEL: func @test_multi_result_single_cast
func.func @test_multi_result_single_cast(%arg0: index) -> (i64, index) {
  %c1_i64 = arith.constant 1 : i64
  %c2_i64 = arith.constant 2 : i64
  %c3_i64 = arith.constant 3 : i64
  %c4_i64 = arith.constant 4 : i64

  // CHECK: %[[RESULT:.*]]:2 = affine.if
  // CHECK: affine.yield %c1_i64, %{{.*}}
  // CHECK: else
  // CHECK: %[[CAST:.*]] = arith.index_cast %c4_i64 : i64 to index
  // CHECK: affine.yield %c3_i64, %[[CAST]]
  // CHECK-NOT: arith.index_cast %{{.*}}#1

  %0:2 = affine.if #set0(%arg0) -> (i64, i64) {
    affine.yield %c1_i64, %c2_i64 : i64, i64
  } else {
    affine.yield %c3_i64, %c4_i64 : i64, i64
  }

  %1 = arith.index_cast %0#1 : i64 to index
  return %0#0, %1 : i64, index
}
56+
57+
// -----
58+
59+
// Multiple results with all used in index_cast: every yield operand is
// cast inside its branch.
// CHECK-LABEL: func @test_multi_result_all_cast
func.func @test_multi_result_all_cast(%arg0: index) -> (index, index) {
  %c1_i64 = arith.constant 1 : i64
  %c2_i64 = arith.constant 2 : i64
  %c3_i64 = arith.constant 3 : i64
  %c4_i64 = arith.constant 4 : i64

  // CHECK: %[[RESULT:.*]]:2 = affine.if
  // CHECK: %[[CAST1:.*]] = arith.index_cast %c1_i64 : i64 to index
  // CHECK: %[[CAST2:.*]] = arith.index_cast %c2_i64 : i64 to index
  // CHECK: affine.yield %[[CAST1]], %[[CAST2]]
  // CHECK: else
  // CHECK: %[[CAST3:.*]] = arith.index_cast %c3_i64 : i64 to index
  // CHECK: %[[CAST4:.*]] = arith.index_cast %c4_i64 : i64 to index
  // CHECK: affine.yield %[[CAST3]], %[[CAST4]]

  %0:2 = affine.if #set0(%arg0) -> (i64, i64) {
    affine.yield %c1_i64, %c2_i64 : i64, i64
  } else {
    affine.yield %c3_i64, %c4_i64 : i64, i64
  }

  %1 = arith.index_cast %0#0 : i64 to index
  %2 = arith.index_cast %0#1 : i64 to index
  return %1, %2 : index, index
}
86+
87+
// -----
88+
89+
// Case where the result is used in multiple places (an index_cast AND a
// direct use) - should not optimize.
// CHECK-LABEL: func @test_multiple_uses
func.func @test_multiple_uses(%arg0: index) -> (index, i64) {
  %c1_i64 = arith.constant 1 : i64
  %c2_i64 = arith.constant 2 : i64

  // CHECK: %[[RESULT:.*]] = affine.if
  // CHECK-NOT: arith.index_cast
  // CHECK: affine.yield %c1_i64
  // CHECK: else
  // CHECK-NOT: arith.index_cast
  // CHECK: affine.yield %c2_i64
  // CHECK: arith.index_cast %[[RESULT]] : i64 to index

  %0 = affine.if #set0(%arg0) -> (i64) {
    affine.yield %c1_i64 : i64
  } else {
    affine.yield %c2_i64 : i64
  }

  %1 = arith.index_cast %0 : i64 to index
  return %1, %0 : index, i64
}
112+
113+
// -----
114+
115+
// Case with multiple index_casts applied to the same result - should not
// optimize (the pattern requires a unique index_cast user).
// CHECK-LABEL: func @test_multiple_casts
func.func @test_multiple_casts(%arg0: index) -> (index, index) {
  %c1_i64 = arith.constant 1 : i64
  %c2_i64 = arith.constant 2 : i64

  // CHECK: %[[RESULT:.*]] = affine.if
  // CHECK-NOT: arith.index_cast
  // CHECK: affine.yield %c1_i64
  // CHECK: else
  // CHECK-NOT: arith.index_cast
  // CHECK: affine.yield %c2_i64
  // CHECK: %[[CAST1:.*]] = arith.index_cast %[[RESULT]] : i64 to index
  // CHECK: %[[CAST2:.*]] = arith.index_cast %[[RESULT]] : i64 to index

  %0 = affine.if #set0(%arg0) -> (i64) {
    affine.yield %c1_i64 : i64
  } else {
    affine.yield %c2_i64 : i64
  }

  %1 = arith.index_cast %0 : i64 to index
  %2 = arith.index_cast %0 : i64 to index
  return %1, %2 : index, index
}
140+
141+
// -----
142+
143+
// Case where the result is already an index type with no index_cast user
// - should not optimize.
// CHECK-LABEL: func @test_already_index
func.func @test_already_index(%arg0: index) -> index {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index

  // CHECK: %[[RESULT:.*]] = affine.if
  // CHECK-NOT: arith.index_cast
  // CHECK: affine.yield %c1
  // CHECK: else
  // CHECK-NOT: arith.index_cast
  // CHECK: affine.yield %c2

  %0 = affine.if #set0(%arg0) -> (index) {
    affine.yield %c1 : index
  } else {
    affine.yield %c2 : index
  }

  return %0 : index
}

0 commit comments

Comments
 (0)