Skip to content

Commit 6e90f13

Browse files
authored
[mlir][spirv] Drop support for SPV_NV_cooperative_matrix (#76782)
This extension has been superseded by SPV_KHR_cooperative_matrix which is supported across major vendors GPU like Nvidia, AMD, and Intel. Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintenance burden and code duplication.
1 parent 0fe86f9 commit 6e90f13

26 files changed

+49
-1382
lines changed

mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h

+3-9
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,10 @@ void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
3131
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
3232
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
3333

34-
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
35-
/// using the NV Cooperative Matrix extension.
36-
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
37-
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
38-
39-
/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type
40-
/// conversion to the type converter. Defaults to KHR cooperative matrix types.
41-
/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
34+
/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type
35+
/// conversion to the type converter.
4236
void populateMMAToSPIRVCoopMatrixTypeConversion(
43-
SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
37+
SPIRVTypeConverter &typeConverter);
4438
} // namespace mlir
4539

4640
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H

mlir/include/mlir/Conversion/Passes.td

-4
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,6 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
564564
Option<"use64bitIndex", "use-64bit-index",
565565
"bool", /*default=*/"false",
566566
"Use 64-bit integers to convert index types">,
567-
Option<"useCoopMatrixNV", "use-coop-matrix-nv",
568-
"bool", /*default=*/"false",
569-
"Use the NV cooperative matrix extension insted of the KHR extension"
570-
" to lower GPU WMMA ops">,
571567
];
572568
}
573569

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

+6-32
Original file line numberDiff line numberDiff line change
@@ -1253,12 +1253,6 @@ def SPIRV_C_RayTracingProvisionalKHR : I32EnumAttrCase<"RayTr
12531253
Extension<[SPV_KHR_ray_tracing]>
12541254
];
12551255
}
1256-
def SPIRV_C_CooperativeMatrixNV : I32EnumAttrCase<"CooperativeMatrixNV", 5357> {
1257-
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
1258-
list<Availability> availability = [
1259-
Extension<[SPV_NV_cooperative_matrix]>
1260-
];
1261-
}
12621256
def SPIRV_C_FragmentShaderSampleInterlockEXT : I32EnumAttrCase<"FragmentShaderSampleInterlockEXT", 5363> {
12631257
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
12641258
list<Availability> availability = [
@@ -1501,7 +1495,7 @@ def SPIRV_CapabilityAttr :
15011495
SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray,
15021496
SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV,
15031497
SPIRV_C_RayTracingMotionBlurNV, SPIRV_C_PhysicalStorageBufferAddresses,
1504-
SPIRV_C_RayTracingProvisionalKHR, SPIRV_C_CooperativeMatrixNV,
1498+
SPIRV_C_RayTracingProvisionalKHR,
15051499
SPIRV_C_FragmentShaderSampleInterlockEXT,
15061500
SPIRV_C_FragmentShaderShadingRateInterlockEXT, SPIRV_C_ShaderSMBuiltinsNV,
15071501
SPIRV_C_FragmentShaderPixelInterlockEXT, SPIRV_C_DemoteToHelperInvocation,
@@ -4123,8 +4117,6 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
41234117
def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
41244118
def SPIRV_IsCooperativeMatrixType :
41254119
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
4126-
def SPIRV_IsCooperativeMatrixNVType :
4127-
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
41284120
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
41294121
def SPIRV_IsJointMatrixType :
41304122
CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">;
@@ -4157,9 +4149,6 @@ def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
41574149
def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
41584150
SPIRV_IsCooperativeMatrixType,
41594151
"any SPIR-V cooperative matrix type">;
4160-
def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
4161-
SPIRV_IsCooperativeMatrixNVType,
4162-
"any SPIR-V NV cooperative matrix type">;
41634152
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
41644153
"any SPIR-V image type">;
41654154
def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
@@ -4178,13 +4167,12 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
41784167
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
41794168
def SPIRV_Composite :
41804169
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
4181-
SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
4182-
SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
4170+
SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
41834171
def SPIRV_Type : AnyTypeOf<[
41844172
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
41854173
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
4186-
SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
4187-
SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
4174+
SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
4175+
SPIRV_AnySampledImage
41884176
]>;
41894177

41904178
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4195,11 +4183,6 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
41954183
"::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
41964184
"Cooperative Matrix">;
41974185

4198-
class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
4199-
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
4200-
"::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
4201-
"Cooperative Matrix NV">;
4202-
42034186
class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
42044187
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
42054188
"::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
@@ -4213,12 +4196,11 @@ class SPIRV_ScalarOrVectorOf<Type type> :
42134196

42144197
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
42154198
AnyTypeOf<[type, SPIRV_VectorOf<type>,
4216-
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
4199+
SPIRV_CoopMatrixOfType<[type]>]>;
42174200

42184201
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
42194202
AnyTypeOf<[SPIRV_AnyMatrix,
4220-
SPIRV_CoopMatrixOfType<[type]>,
4221-
SPIRV_CoopMatrixNVOfType<[type]>]>;
4203+
SPIRV_CoopMatrixOfType<[type]>]>;
42224204

42234205
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
42244206
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
@@ -4480,11 +4462,6 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrix
44804462
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
44814463
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
44824464
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
4483-
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
4484-
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
4485-
def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
4486-
def SPIRV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
4487-
def SPIRV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
44884465
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
44894466
def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
44904467
def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4585,9 +4562,6 @@ def SPIRV_OpcodeAttr :
45854562
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
45864563
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
45874564
SPIRV_OC_OpCooperativeMatrixLengthKHR,
4588-
SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
4589-
SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
4590-
SPIRV_OC_OpCooperativeMatrixLengthNV,
45914565
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
45924566
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
45934567
SPIRV_OC_OpGroupFMulKHR,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

-247
Original file line numberDiff line numberDiff line change
@@ -338,253 +338,6 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
338338
];
339339
}
340340

341-
//===----------------------------------------------------------------------===//
342-
// SPV_NV_cooperative_matrix extension ops.
343-
//===----------------------------------------------------------------------===//
344-
345-
// -----
346-
347-
def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
348-
[Pure]> {
349-
let summary = "See extension SPV_NV_cooperative_matrix";
350-
351-
let description = [{
352-
Number of components of a cooperative matrix type accessible to each
353-
invocation when treated as a composite.
354-
355-
Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness.
356-
357-
Type is a cooperative matrix type.
358-
359-
#### Example:
360-
361-
```
362-
%0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
363-
```
364-
}];
365-
366-
let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
367-
368-
let availability = [
369-
MinVersion<SPIRV_V_1_0>,
370-
MaxVersion<SPIRV_V_1_6>,
371-
Extension<[SPV_NV_cooperative_matrix]>,
372-
Capability<[SPIRV_C_CooperativeMatrixNV]>
373-
];
374-
375-
let arguments = (ins
376-
TypeAttr:$cooperative_matrix_type
377-
);
378-
379-
let results = (outs
380-
SPIRV_Int32:$result
381-
);
382-
}
383-
384-
// -----
385-
386-
def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad", []> {
387-
let summary = "See extension SPV_NV_cooperative_matrix";
388-
389-
let description = [{
390-
Load a cooperative matrix through a pointer.
391-
392-
Result Type is the type of the loaded object. It must be a cooperative
393-
matrix type.
394-
395-
Pointer is a pointer into an array. Its type must be an OpTypePointer whose
396-
Type operand is a scalar or vector type. The storage class of Pointer must
397-
be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
398-
supported) PhysicalStorageBufferEXT.
399-
400-
Stride is the number of elements in the array in memory between the first
401-
component of consecutive rows (or columns) in the result. It must be a
402-
scalar integer type.
403-
404-
ColumnMajor indicates whether the values loaded from memory are arranged in
405-
column-major or row-major order. It must be a boolean constant instruction,
406-
with false indicating row major and true indicating column major.
407-
408-
Memory Access must be a Memory Access literal. If not present, it is the
409-
same as specifying None.
410-
411-
If ColumnMajor is false, then elements (row,*) of the result are taken in
412-
order from contiguous locations starting at Pointer[row*Stride]. If
413-
ColumnMajor is true, then elements (*,col) of the result are taken in order
414-
from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
415-
decoration on Pointer is ignored.
416-
417-
For a given dynamic instance of this instruction, all operands of this
418-
instruction must be the same for all invocations in a given scope instance
419-
(where the scope is the scope the cooperative matrix type was created with).
420-
All invocations in a given scope instance must be active or all must be
421-
inactive.
422-
423-
### Custom assembly form
424-
425-
``` {.ebnf}
426-
cooperative-matrixload-op ::= ssa-id `=` `spirv.NV.CooperativeMatrixLoad`
427-
ssa-use `,` ssa-use `,` ssa-use
428-
(`[` memory-access `]`)? ` : `
429-
pointer-type `as`
430-
cooperative-matrix-type
431-
```
432-
433-
#### Example:
434-
435-
```
436-
%0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
437-
: !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
438-
```
439-
}];
440-
441-
let availability = [
442-
MinVersion<SPIRV_V_1_0>,
443-
MaxVersion<SPIRV_V_1_6>,
444-
Extension<[SPV_NV_cooperative_matrix]>,
445-
Capability<[SPIRV_C_CooperativeMatrixNV]>
446-
];
447-
448-
let arguments = (ins
449-
SPIRV_AnyPtr:$pointer,
450-
SPIRV_Integer:$stride,
451-
SPIRV_Bool:$columnmajor,
452-
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
453-
);
454-
455-
let results = (outs
456-
SPIRV_AnyCooperativeMatrixNV:$result
457-
);
458-
}
459-
460-
// -----
461-
462-
def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAdd",
463-
[Pure, AllTypesMatch<["c", "result"]>]> {
464-
let summary = "See extension SPV_NV_cooperative_matrix";
465-
466-
let description = [{
467-
Linear-algebraic matrix multiply of A by B and then component-wise add C.
468-
The order of the operations is implementation-dependent. The internal
469-
precision of floating-point operations is defined by the client API.
470-
Integer operations are performed at the precision of the Result Type and are
471-
exact unless there is overflow or underflow, in which case the result is
472-
undefined.
473-
474-
Result Type must be a cooperative matrix type with M rows and N columns.
475-
476-
A is a cooperative matrix with M rows and K columns.
477-
478-
B is a cooperative matrix with K rows and N columns.
479-
480-
C is a cooperative matrix with M rows and N columns.
481-
482-
The values of M, N, and K must be consistent across the result and operands.
483-
This is referred to as an MxNxK matrix multiply.
484-
485-
A, B, C, and Result Type must have the same scope, and this defines the
486-
scope of the operation. A, B, C, and Result Type need not necessarily have
487-
the same component type, this is defined by the client API.
488-
489-
If the Component Type of any matrix operand is an integer type, then its
490-
components are treated as signed if its Component Type has Signedness of 1
491-
and are treated as unsigned otherwise.
492-
493-
For a given dynamic instance of this instruction, all invocations in a given
494-
scope instance must be active or all must be inactive (where the scope is
495-
the scope of the operation).
496-
497-
#### Example:
498-
499-
```
500-
%0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, :
501-
!spirv.NV.coopmatrix<8x16xi32, Subgroup>
502-
```
503-
}];
504-
505-
let assemblyFormat = [{
506-
operands attr-dict `:` type($a) `,` type($b) `->` type($c)
507-
}];
508-
509-
let availability = [
510-
MinVersion<SPIRV_V_1_0>,
511-
MaxVersion<SPIRV_V_1_6>,
512-
Extension<[SPV_NV_cooperative_matrix]>,
513-
Capability<[SPIRV_C_CooperativeMatrixNV]>
514-
];
515-
516-
let arguments = (ins
517-
SPIRV_AnyCooperativeMatrixNV:$a,
518-
SPIRV_AnyCooperativeMatrixNV:$b,
519-
SPIRV_AnyCooperativeMatrixNV:$c
520-
);
521-
522-
let results = (outs
523-
SPIRV_AnyCooperativeMatrixNV:$result
524-
);
525-
}
526-
527-
// -----
528-
529-
def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore", []> {
530-
let summary = "See extension SPV_NV_cooperative_matrix";
531-
532-
let description = [{
533-
Store a cooperative matrix through a pointer.
534-
535-
Pointer is a pointer into an array. Its type must be an OpTypePointer whose
536-
Type operand is a scalar or vector type. The storage class of Pointer must
537-
be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
538-
supported) PhysicalStorageBufferEXT.
539-
540-
Object is the object to store. Its type must be an
541-
OpTypeCooperativeMatrixNV.
542-
543-
Stride is the number of elements in the array in memory between the first
544-
component of consecutive rows (or columns) in the result. It must be a
545-
scalar integer type.
546-
547-
ColumnMajor indicates whether the values stored to memory are arranged in
548-
column-major or row-major order. It must be a boolean constant instruction,
549-
with false indicating row major and true indicating column major.
550-
551-
Memory Access must be a Memory Access literal. If not present, it is the
552-
same as specifying None.
553-
554-
``` {.ebnf}
555-
coop-matrix-store-op ::= `spirv.NV.CooperativeMatrixStore `
556-
ssa-use `, ` ssa-use `, `
557-
ssa-use `, ` ssa-use `, `
558-
(`[` memory-access `]`)? `:`
559-
pointer-type `,` coop-matrix-type
560-
```
561-
562-
#### Example:
563-
564-
```
565-
spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
566-
!spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
567-
```
568-
}];
569-
570-
let availability = [
571-
MinVersion<SPIRV_V_1_0>,
572-
MaxVersion<SPIRV_V_1_6>,
573-
Extension<[SPV_NV_cooperative_matrix]>,
574-
Capability<[SPIRV_C_CooperativeMatrixNV]>
575-
];
576-
577-
let arguments = (ins
578-
SPIRV_AnyPtr:$pointer,
579-
SPIRV_AnyCooperativeMatrixNV:$object,
580-
SPIRV_Integer:$stride,
581-
SPIRV_Bool:$columnmajor,
582-
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
583-
);
584-
585-
let results = (outs);
586-
}
587-
588341
// -----
589342

590343
#endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS

0 commit comments

Comments
 (0)