|
24 | 24 | #include "mlir/IR/BuiltinAttributes.h"
|
25 | 25 | #include "mlir/IR/BuiltinTypes.h"
|
26 | 26 | #include "mlir/IR/TypeUtilities.h"
|
| 27 | +#include "mlir/IR/ValueRange.h" |
| 28 | +#include "llvm/ADT/STLExtras.h" |
27 | 29 | #include "llvm/ADT/StringSwitch.h"
|
28 | 30 |
|
29 | 31 | #include <cassert>
|
30 | 32 |
|
31 | 33 | namespace mlir {
|
| 34 | +//===----------------------------------------------------------------------===// |
| 35 | +// Patterns and helpers used by both the KHR and the NV lowering paths. |
| 36 | +//===----------------------------------------------------------------------===// |
| 37 | + |
32 | 38 | /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
|
33 | 39 | /// when the elementwise op directly supports the cooperative matrix type.

34 | 40 | /// Returns false if it cannot.
|
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
|
77 | 83 | return false;
|
78 | 84 | }
|
79 | 85 |
|
| 86 | +static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) { |
| 87 | + assert(!operands.empty()); |
| 88 | + if (!llvm::all_equal( |
| 89 | + llvm::map_range(operands, [](Value v) { return v.getType(); }))) |
| 90 | + return false; |
| 91 | + |
| 92 | + return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>( |
| 93 | + operands.front().getType()); |
| 94 | +} |
| 95 | + |
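The helper accepts both the KHR and the NV cooperative matrix types, which is what lets the patterns below be shared by the two lowering paths. For readers unfamiliar with the `llvm::all_equal`/`llvm::map_range` combination, a hand-written equivalent (a sketch only; the `...Manual` name is hypothetical):

    // Hand-rolled equivalent of allOperandsHaveSameCoopMatrixType (sketch).
    static bool allOperandsHaveSameCoopMatrixTypeManual(mlir::ValueRange operands) {
      assert(!operands.empty() && "expected at least one operand");
      mlir::Type first = operands.front().getType();
      for (mlir::Value v : operands.drop_front())
        if (v.getType() != first)
          return false; // Operand types diverge.
      // All types match; accept either the KHR or the NV cooperative matrix type.
      return llvm::isa<mlir::spirv::CooperativeMatrixType,
                       mlir::spirv::CooperativeMatrixNVType>(first);
    }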
| 96 | +namespace { |
| 97 | +/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative |
| 98 | +/// matrix ops. |
| 99 | +struct WmmaConstantOpToSPIRVLowering final |
| 100 | + : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> { |
| 101 | + using OpConversionPattern::OpConversionPattern; |
| 102 | + |
| 103 | + LogicalResult |
| 104 | + matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor, |
| 105 | + ConversionPatternRewriter &rewriter) const override { |
| 106 | + assert(adaptor.getOperands().size() == 1); |
| 107 | + Value cst = adaptor.getOperands().front(); |
| 108 | + auto coopType = getTypeConverter()->convertType(op.getType()); |
| 109 | + if (!coopType) |
| 110 | + return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| 111 | + |
| 112 | + rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst); |
| 113 | + return success(); |
| 114 | + } |
| 115 | +}; |
| 116 | + |
| 117 | +/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
| 118 | +/// the default case. |
| 119 | +struct WmmaElementwiseOpToSPIRVDefaultLowering final |
| 120 | + : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
| 121 | + using OpConversionPattern::OpConversionPattern; |
| 122 | + |
| 123 | + LogicalResult |
| 124 | + matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, |
| 125 | + ConversionPatternRewriter &rewriter) const override { |
| 126 | + // All operands should be of cooperative matrix types. |
| 127 | + if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { |
| 128 | + return rewriter.notifyMatchFailure(op, |
| 129 | + "not all operands are coop matrices"); |
| 130 | + } |
| 131 | + |
| 132 | + auto coopType = getTypeConverter()->convertType(op.getType()); |
| 133 | + if (!coopType) |
| 134 | + return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| 135 | + |
| 136 | + return success( |
| 137 | + createElementwiseOp(rewriter, op, coopType, adaptor.getOperands())); |
| 138 | + } |
| 139 | +}; |
| 140 | + |
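For context, `createElementwiseOp` (defined earlier in this file, outside this hunk) maps the GPU elementwise kind onto the matching SPIR-V arithmetic op and returns false when no direct cooperative-matrix equivalent exists, which is why the pattern wraps its result in `success(...)`. An abridged sketch of the shape of that dispatch (case list illustrative, not exhaustive):

    // Sketch: dispatch from gpu::MMAElementwiseOp kinds to SPIR-V ops.
    static bool createElementwiseOpSketch(ConversionPatternRewriter &builder,
                                          gpu::SubgroupMmaElementwiseOp op,
                                          mlir::Type coopType,
                                          mlir::ValueRange operands) {
      switch (op.getOpType()) {
      case gpu::MMAElementwiseOp::ADDF:
        builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
        return true;
      case gpu::MMAElementwiseOp::MULF:
        builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
        return true;
      default:
        return false; // No direct cooperative matrix equivalent.
      }
    }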
| 141 | +/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
| 142 | +/// matrix times scalar case. |
| 143 | +struct WmmaElementwiseOpToSPIRVScalarMulLowering final |
| 144 | + : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
| 145 | + using OpConversionPattern::OpConversionPattern; |
| 146 | + |
| 147 | + LogicalResult |
| 148 | + matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, |
| 149 | + ConversionPatternRewriter &rewriter) const override { |
| 150 | + if (adaptor.getOperands().size() != 2) |
| 151 | + return failure(); |
| 152 | + |
| 153 | + // All operands should be of cooperative matrix types. |
| 154 | + if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { |
| 155 | + return rewriter.notifyMatchFailure(op, |
| 156 | + "not all operands are coop matrices"); |
| 157 | + } |
| 158 | + |
| 159 | + if (op.getOpType() != gpu::MMAElementwiseOp::MULF) |
| 160 | + return failure(); |
| 161 | + |
| 162 | + // Use the original operands to check whether one of the operands is a splat |
| 163 | + // scalar value. |
| 164 | + Value lhs = op.getOperands().front(); |
| 165 | + Value rhs = op.getOperands().back(); |
| 166 | + Value splat = nullptr; |
| 167 | + Value matrix = nullptr; |
| 168 | + if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
| 169 | + splat = adaptor.getOperands().front(); |
| 170 | + matrix = adaptor.getOperands().back(); |
| 171 | + } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
| 172 | + matrix = adaptor.getOperands().front(); |
| 173 | + splat = adaptor.getOperands().back(); |
| 174 | + } |
| 175 | + if (!splat || !matrix) |
| 176 | + return rewriter.notifyMatchFailure(op, "no splat operand"); |
| 177 | + |
| 178 | + // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops. |
| 179 | + Value scalar; |
| 180 | + auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>(); |
| 181 | + if (!cc) { |
| 182 | + return rewriter.notifyMatchFailure(op, |
| 183 | + "splat is not a composite construct"); |
| 184 | + } |
| 185 | + |
| 186 | + assert(cc.getConstituents().size() == 1); |
| 187 | + scalar = cc.getConstituents().front(); |
| 188 | + |
| 189 | + auto coopType = getTypeConverter()->convertType(op.getType()); |
| 190 | + if (!coopType) |
| 191 | + return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| 192 | + rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>( |
| 193 | + op, coopType, ValueRange{matrix, scalar}); |
| 194 | + return success(); |
| 195 | + } |
| 196 | +}; |
| 197 | +} // namespace |
| 198 | + |
80 | 199 | //===----------------------------------------------------------------------===//
|
81 | 200 | // SPV_KHR_cooperative_matrix
|
82 | 201 | //===----------------------------------------------------------------------===//
|
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
|
262 | 381 | }
|
263 | 382 | };
|
264 | 383 |
|
265 | | -/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix |

266 | | -/// ops. |

267 | | -struct WmmaConstantOpToSPIRVLowering final |

268 | | -    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> { |

269 | | -  using OpConversionPattern::OpConversionPattern; |

270 | | - |

271 | | -  LogicalResult |

272 | | -  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp, |

273 | | -                  OpAdaptor adaptor, |

274 | | -                  ConversionPatternRewriter &rewriter) const override { |

275 | | -    Value cst = adaptor.getOperands()[0]; |

276 | | -    auto coopType = convertMMAToSPIRVCoopMatrixNVType( |

277 | | -        cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType())); |

278 | | -    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( |

279 | | -        subgroupMmaConstantMatrixOp, coopType, cst); |

280 | | -    return success(); |

281 | | -  } |

282 | | -}; |

283 | | - |

284 | | -/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |

285 | | -/// the default case. |

286 | | -struct WmmaElementwiseOpToSPIRVDefaultLowering final |

287 | | -    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |

288 | | -  using OpConversionPattern::OpConversionPattern; |

289 | | - |

290 | | -  LogicalResult |

291 | | -  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp, |

292 | | -                  OpAdaptor adaptor, |

293 | | -                  ConversionPatternRewriter &rewriter) const override { |

294 | | -    // All operands should be of cooperative matrix types. |

295 | | -    for (Value operand : adaptor.getOperands()) { |

296 | | -      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType())) |

297 | | -        return failure(); |

298 | | -    } |

299 | | -    auto coopType = convertMMAToSPIRVCoopMatrixNVType( |

300 | | -        cast<gpu::MMAMatrixType>(elementwiseOp.getType())); |

301 | | -    return success(createElementwiseOp(rewriter, elementwiseOp, coopType, |

302 | | -                                       adaptor.getOperands())); |

303 | | -  } |

304 | | -}; |

305 | | - |

306 | | -/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |

307 | | -/// matrix times scalar case. |

308 | | -struct WmmaElementwiseOpToSPIRVScalarMulLowering final |

309 | | -    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |

310 | | -  using OpConversionPattern::OpConversionPattern; |

311 | | - |

312 | | -  LogicalResult |

313 | | -  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp, |

314 | | -                  OpAdaptor adaptor, |

315 | | -                  ConversionPatternRewriter &rewriter) const override { |

316 | | -    if (adaptor.getOperands().size() != 2) |

317 | | -      return failure(); |

318 | | -    // All operands should be of cooperative matrix types. |

319 | | -    for (Value operand : adaptor.getOperands()) { |

320 | | -      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType())) |

321 | | -        return failure(); |

322 | | -    } |

323 | | - |

324 | | -    if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF) |

325 | | -      return failure(); |

326 | | - |

327 | | -    // Use the original operands to check whether one of the operands is a splat |

328 | | -    // scalar value. |

329 | | -    Value lhs = elementwiseOp.getOperands().front(); |

330 | | -    Value rhs = elementwiseOp.getOperands().back(); |

331 | | -    Value splat = nullptr; |

332 | | -    Value matrix = nullptr; |

333 | | -    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |

334 | | -      splat = adaptor.getOperands().front(); |

335 | | -      matrix = adaptor.getOperands().back(); |

336 | | -    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |

337 | | -      matrix = adaptor.getOperands().front(); |

338 | | -      splat = adaptor.getOperands().back(); |

339 | | -    } |

340 | | -    if (!splat || !matrix) |

341 | | -      return failure(); |

342 | | - |

343 | | -    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops. |

344 | | -    Value scalar = nullptr; |

345 | | -    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>(); |

346 | | -    if (!cc) |

347 | | -      return failure(); |

348 | | -    assert(cc.getConstituents().size() == 1); |

349 | | -    scalar = cc.getConstituents().front(); |

350 | | - |

351 | | -    auto coopType = convertMMAToSPIRVCoopMatrixNVType( |

352 | | -        cast<gpu::MMAMatrixType>(elementwiseOp.getType())); |

353 | | -    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>( |

354 | | -        elementwiseOp, coopType, ValueRange{matrix, scalar}); |

355 | | -    return success(); |

356 | | -  } |

357 | | -}; |

358 | | - |
359 | 384 | } // namespace
|
360 | 385 | } // namespace nv
|
361 | 386 | } // namespace mlir
|
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
|
389 | 414 | using namespace mlir;
|
390 | 415 | MLIRContext *context = patterns.getContext();
|
391 | 416 | patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
|
392 | | -               khr::WmmaStoreOpToSPIRVLowering>(converter, context); |

| 417 | +               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering, |

| 418 | +               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context); |

| 419 | +  // Give the following patterns higher benefit to prevail over the default one. |

| 420 | +  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context, |

| 421 | +                                                          /*benefit=*/2); |
393 | 422 | }
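Both elementwise patterns match `gpu.subgroup_mma_elementwise`, so the benefit is what orders them: the pattern driver tries higher-benefit patterns first, letting the scalar-multiplication specialization run before the generic lowering and fall through (via `failure()`/`notifyMatchFailure`) when it does not apply. The trailing argument is forwarded to the `OpConversionPattern` constructor; an equivalent, more explicit spelling would be:

    // Hypothetical, more explicit spelling of the registration above.
    patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(
        converter, context, PatternBenefit(2));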
|
394 | 423 |
|
395 | 424 | void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
|
396 | 425 | SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
|
397 | 426 | using namespace mlir;
|
398 | 427 | MLIRContext *context = patterns.getContext();
|
399 | | -  patterns |

400 | | -      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering, |

401 | | -           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering, |

402 | | -           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context); |

| 428 | +  patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering, |

| 429 | +               nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering, |

| 430 | +               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context); |
403 | 431 | // Give the following patterns higher benefit to prevail over the default one.
|
404 | | -  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, |

405 | | -                                                               context, |

406 | | -                                                               /*benefit=*/2); |

| 432 | +  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context, |

| 433 | +                                                          /*benefit=*/2); |
407 | 434 | }
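For completeness, a minimal sketch of how a conversion pass might consume one of these entry points (the driver function name and target setup here are illustrative, not the actual GPU-to-SPIR-V pass):

    #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
    #include "mlir/Dialect/GPU/IR/GPUDialect.h"
    #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    // Hypothetical driver: lower WMMA ops under `root` via the KHR path.
    static LogicalResult lowerWmmaToCoopMatrixKHR(Operation *root,
                                                  SPIRVTypeConverter &typeConverter) {
      MLIRContext *ctx = root->getContext();
      RewritePatternSet patterns(ctx);
      populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
                                                            patterns);

      ConversionTarget target(*ctx);
      target.addLegalDialect<spirv::SPIRVDialect>(); // Converted ops are legal.
      target.addIllegalDialect<gpu::GPUDialect>();   // WMMA ops must be rewritten.
      return applyPartialConversion(root, target, std::move(patterns));
    }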
|