Skip to content

Commit c46a043

Browse files
[mlir][arith] Rename AtomicRMWKind's maxfmaximumf, minfminimumf (#66135)
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. This commit renames `maxf` and `minf` enumerators of `AtomicRMWKind` to better reflect the current naming scheme and the goals of the RFC.
1 parent f94695b commit c46a043

File tree

11 files changed

+43
-41
lines changed

11 files changed

+43
-41
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

+15-15
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,25 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
6969
let cppNamespace = "::mlir::arith";
7070
}
7171

72-
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
73-
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
74-
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
75-
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
76-
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
77-
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
78-
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
79-
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
80-
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
81-
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
82-
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
83-
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
84-
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
72+
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
73+
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
74+
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
75+
def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
76+
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
77+
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
78+
def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
79+
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
80+
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
81+
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
82+
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
83+
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
84+
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
8585

8686
def AtomicRMWKindAttr : I64EnumAttr<
8787
"AtomicRMWKind", "",
8888
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
89-
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
90-
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
89+
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
90+
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
9191
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
9292
ATOMIC_RMW_KIND_ANDI]> {
9393
let cppNamespace = "::mlir::arith";

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1594,13 +1594,13 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
15941594
return LLVM::AtomicBinOp::add;
15951595
case arith::AtomicRMWKind::assign:
15961596
return LLVM::AtomicBinOp::xchg;
1597-
case arith::AtomicRMWKind::maxf:
1597+
case arith::AtomicRMWKind::maximumf:
15981598
return LLVM::AtomicBinOp::fmax;
15991599
case arith::AtomicRMWKind::maxs:
16001600
return LLVM::AtomicBinOp::max;
16011601
case arith::AtomicRMWKind::maxu:
16021602
return LLVM::AtomicBinOp::umax;
1603-
case arith::AtomicRMWKind::minf:
1603+
case arith::AtomicRMWKind::minimumf:
16041604
return LLVM::AtomicBinOp::fmin;
16051605
case arith::AtomicRMWKind::mins:
16061606
return LLVM::AtomicBinOp::min;

mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
6060
.Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
6161
.Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
6262
.Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
63-
.Case([](arith::MinimumFOp) { return arith::AtomicRMWKind::minf; })
64-
.Case([](arith::MaximumFOp) { return arith::AtomicRMWKind::maxf; })
63+
.Case(
64+
[](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
65+
.Case(
66+
[](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
6567
.Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
6668
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
6769
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -2369,7 +2369,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
23692369
OpBuilder &builder, Location loc,
23702370
bool useOnlyFiniteValue) {
23712371
switch (kind) {
2372-
case AtomicRMWKind::maxf: {
2372+
case AtomicRMWKind::maximumf: {
23732373
const llvm::fltSemantics &semantic =
23742374
llvm::cast<FloatType>(resultType).getFloatSemantics();
23752375
APFloat identity = useOnlyFiniteValue
@@ -2390,7 +2390,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
23902390
return builder.getIntegerAttr(
23912391
resultType, APInt::getSignedMinValue(
23922392
llvm::cast<IntegerType>(resultType).getWidth()));
2393-
case AtomicRMWKind::minf: {
2393+
case AtomicRMWKind::minimumf: {
23942394
const llvm::fltSemantics &semantic =
23952395
llvm::cast<FloatType>(resultType).getFloatSemantics();
23962396
APFloat identity = useOnlyFiniteValue
@@ -2426,8 +2426,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
24262426
// Floating-point operations.
24272427
.Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
24282428
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2429-
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maxf; })
2430-
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minf; })
2429+
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2430+
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
24312431
// Integer operations.
24322432
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
24332433
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
@@ -2482,9 +2482,9 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
24822482
return builder.create<arith::MulFOp>(loc, lhs, rhs);
24832483
case AtomicRMWKind::muli:
24842484
return builder.create<arith::MulIOp>(loc, lhs, rhs);
2485-
case AtomicRMWKind::maxf:
2485+
case AtomicRMWKind::maximumf:
24862486
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2487-
case AtomicRMWKind::minf:
2487+
case AtomicRMWKind::minimumf:
24882488
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
24892489
case AtomicRMWKind::maxs:
24902490
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -2549,9 +2549,9 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
25492549
dims.erase(dims.begin() + reductionDim);
25502550
// Step 1: Compute max along dim.
25512551
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2552-
Value neutralForMaxF =
2553-
arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc,
2554-
/*useOnlyFiniteValue=*/true);
2552+
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2553+
elementType, b, loc,
2554+
/*useOnlyFiniteValue=*/true);
25552555
Value neutralForMaxFInit =
25562556
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
25572557
.result();

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -3402,8 +3402,8 @@ LogicalResult AtomicRMWOp::verify() {
34023402
"expects the number of subscripts to be equal to memref rank");
34033403
switch (getKind()) {
34043404
case arith::AtomicRMWKind::addf:
3405-
case arith::AtomicRMWKind::maxf:
3406-
case arith::AtomicRMWKind::minf:
3405+
case arith::AtomicRMWKind::maximumf:
3406+
case arith::AtomicRMWKind::minimumf:
34073407
case arith::AtomicRMWKind::mulf:
34083408
if (!llvm::isa<FloatType>(getValue().getType()))
34093409
return emitOpError() << "with kind '"

mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace {
3636
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
3737
/// `memref.generic_atomic_rmw` with the expanded code.
3838
///
39-
/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
39+
/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
4040
///
4141
/// will be lowered to
4242
///
@@ -54,10 +54,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
5454
PatternRewriter &rewriter) const final {
5555
arith::CmpFPredicate predicate;
5656
switch (op.getKind()) {
57-
case arith::AtomicRMWKind::maxf:
57+
case arith::AtomicRMWKind::maximumf:
5858
predicate = arith::CmpFPredicate::OGT;
5959
break;
60-
case arith::AtomicRMWKind::minf:
60+
case arith::AtomicRMWKind::minimumf:
6161
predicate = arith::CmpFPredicate::OLT;
6262
break;
6363
default:
@@ -137,8 +137,8 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
137137
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
138138
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
139139
[](memref::AtomicRMWOp op) {
140-
return op.getKind() != arith::AtomicRMWKind::maxf &&
141-
op.getKind() != arith::AtomicRMWKind::minf;
140+
return op.getKind() != arith::AtomicRMWKind::maximumf &&
141+
op.getKind() != arith::AtomicRMWKind::minimumf;
142142
});
143143
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
144144
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
493493
case arith::AtomicRMWKind::muli:
494494
return builder.create<vector::ReductionOp>(vector.getLoc(),
495495
CombiningKind::MUL, vector);
496-
case arith::AtomicRMWKind::minf:
496+
case arith::AtomicRMWKind::minimumf:
497497
return builder.create<vector::ReductionOp>(vector.getLoc(),
498498
CombiningKind::MINF, vector);
499499
case arith::AtomicRMWKind::mins:
@@ -502,7 +502,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
502502
case arith::AtomicRMWKind::minu:
503503
return builder.create<vector::ReductionOp>(vector.getLoc(),
504504
CombiningKind::MINUI, vector);
505-
case arith::AtomicRMWKind::maxf:
505+
case arith::AtomicRMWKind::maximumf:
506506
return builder.create<vector::ReductionOp>(vector.getLoc(),
507507
CombiningKind::MAXF, vector);
508508
case arith::AtomicRMWKind::maxs:

mlir/test/Dialect/Affine/invalid.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
287287

288288
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
289289
%0 = memref.alloc() : memref<100x100xi32>
290-
%1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minf") -> (f32) {
290+
%1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minimumf") -> (f32) {
291291
%2 = affine.load %0[%i, %j] : memref<100x100xi32>
292292
// expected-error@+1 {{types mismatch between yield op and its parent}}
293293
affine.yield %2 : i32

mlir/test/Dialect/Affine/ops.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
158158
func.func @parallel(%A : memref<100x100xf32>, %N : index) {
159159
// CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10)
160160
affine.parallel (%i0, %j0) = (0, 0) to (symbol(%N), 100) step (10, 10) {
161-
// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minf", "maxf") -> (f32, f32)
162-
%0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minf", "maxf") -> (f32, f32) {
161+
// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minimumf", "maximumf") -> (f32, f32)
162+
%0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minimumf", "maximumf") -> (f32, f32) {
163163
%2 = affine.load %A[%i0 + %i0, %j0 + %j1] : memref<100x100xf32>
164164
affine.yield %2, %2 : f32, f32
165165
}

mlir/test/Dialect/MemRef/expand-ops.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// CHECK-LABEL: func @atomic_rmw_to_generic
44
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
55
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
6-
%x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
6+
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
77
return %x : f32
88
}
99
// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {

0 commit comments

Comments
 (0)