Skip to content

Commit 992b575

[mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix
These ops do not produce extension-specific SPIR-V ops and are now handled via patterns common to both the KHR and the NV coop matrix extensions. Also improve match failure reporting and error handling in type conversion.
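
For illustration, a minimal before/after sketch of the kind of rewrite the shared patterns perform on the KHR path (distilled from the tests added below; module and kernel boilerplate omitted):

    // Input: a splat constant wmma matrix.
    %cst = arith.constant 1.0 : f16
    %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">

    // After conversion: the splat scalar is wrapped into a cooperative matrix.
    %cst = spirv.Constant 1.000000e+00 : f16
    %C = spirv.CompositeConstruct %cst : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>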
1 parent ed4daea commit 992b575

File tree

2 files changed: +224 −103 lines


mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

+129 −102
@@ -24,11 +24,17 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
 #include <cassert>
 
 namespace mlir {
+//===----------------------------------------------------------------------===//
+// Patterns and helpers used by both the KHR and the NV lowering paths.
+//===----------------------------------------------------------------------===//
+
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
 /// Returns false if cannot.
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
+  assert(!operands.empty());
+  if (!llvm::all_equal(
+          llvm::map_range(operands, [](Value v) { return v.getType(); })))
+    return false;
+
+  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      operands.front().getType());
+}
+
+namespace {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    Value cst = adaptor.getOperands().front();
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
+    return success();
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    return success(
+        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// matrix times scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
+      return failure();
+
+    // Use the original operands to check whether one of the operands is a splat
+    // scalar value.
+    Value lhs = op.getOperands().front();
+    Value rhs = op.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return rewriter.notifyMatchFailure(op, "no splat operand");
+
+    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
+    Value scalar;
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc) {
+      return rewriter.notifyMatchFailure(op,
+                                         "splat is not a composite construct");
+    }
+
+    assert(cc.getConstituents().size() == 1);
+    scalar = cc.getConstituents().front();
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        op, coopType, ValueRange{matrix, scalar});
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // SPV_KHR_cooperative_matrix
 //===----------------------------------------------------------------------===//
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
   }
 };
 
-/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
-/// ops.
-struct WmmaConstantOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        subgroupMmaConstantMatrixOp, coopType, cst);
-    return success();
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
-                                       adaptor.getOperands()));
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (adaptor.getOperands().size() != 2)
-      return failure();
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-
-    if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
-      return failure();
-
-    // Use the original operands to check whether one of the operands is a splat
-    // scalar value.
-    Value lhs = elementwiseOp.getOperands().front();
-    Value rhs = elementwiseOp.getOperands().back();
-    Value splat = nullptr;
-    Value matrix = nullptr;
-    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
-      splat = adaptor.getOperands().front();
-      matrix = adaptor.getOperands().back();
-    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
-      matrix = adaptor.getOperands().front();
-      splat = adaptor.getOperands().back();
-    }
-    if (!splat || !matrix)
-      return failure();
-
-    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
-    Value scalar = nullptr;
-    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
-    if (!cc)
-      return failure();
-    assert(cc.getConstituents().size() == 1);
-    scalar = cc.getConstituents().front();
-
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
-        elementwiseOp, coopType, ValueRange{matrix, scalar});
-    return success();
-  }
-};
-
 } // namespace
 } // namespace nv
 } // namespace mlir
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   using namespace mlir;
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
-               khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  // Give the following patterns higher benefit to prevail over the default one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
 
 void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
   MLIRContext *context = patterns.getContext();
-  patterns
-      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
-           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
-           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+               nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
-                                                              context,
-                                                              /*benefit=*/2);
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
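
A note on pattern ordering: WmmaElementwiseOpToSPIRVScalarMulLowering is registered with /*benefit=*/2 so that it is tried before the default elementwise lowering when both match. The effect, sketched from the KHR test added below (boilerplate omitted), is that a mulf whose other operand is a splat constant matrix lowers to spirv.MatrixTimesScalar instead of an elementwise multiply:

    %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
    %C = gpu.subgroup_mma_elementwise mulf %A, %B :
      (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">

    // After conversion, the splat is traced back to its defining scalar:
    %C = spirv.MatrixTimesScalar %A, %scalar : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16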

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

+95 −1
@@ -69,12 +69,106 @@ module attributes {
           -> !gpu.mma_matrix<16x16xf16, "COp">
 
       %i = arith.constant 0 : index
-      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
       gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
         !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
     }
 
+    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+    gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
+      %cst = arith.constant 1.0 : f16
+      // CHECK: %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
+      // CHECK-SAME: (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
+                                              %B: !gpu.mma_matrix<16x16xf16, "COp">,
+                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_elementwise negatef %C :
+        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %E = gpu.subgroup_mma_elementwise divf %D, %A :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
+      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+      %F = gpu.subgroup_mma_elementwise extf %E :
+        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME: %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
+        %A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
+        %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 0 : index
+
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+
+      // CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
+      %D = gpu.subgroup_mma_elementwise mulf %B, %C :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
+    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME: %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
+        %A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
+        %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 0 : index
+
+      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
   }
 }
