@@ -338,253 +338,6 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
338
338
];
339
339
}
340
340
341
- //===----------------------------------------------------------------------===//
342
- // SPV_NV_cooperative_matrix extension ops.
343
- //===----------------------------------------------------------------------===//
344
-
345
- // -----
346
-
347
- def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
348
- [Pure]> {
349
- let summary = "See extension SPV_NV_cooperative_matrix";
350
-
351
- let description = [{
352
- Number of components of a cooperative matrix type accessible to each
353
- invocation when treated as a composite.
354
-
355
- Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness.
356
-
357
- Type is a cooperative matrix type.
358
-
359
- #### Example:
360
-
361
- ```
362
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
363
- ```
364
- }];
365
-
366
- let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
367
-
368
- let availability = [
369
- MinVersion<SPIRV_V_1_0>,
370
- MaxVersion<SPIRV_V_1_6>,
371
- Extension<[SPV_NV_cooperative_matrix]>,
372
- Capability<[SPIRV_C_CooperativeMatrixNV]>
373
- ];
374
-
375
- let arguments = (ins
376
- TypeAttr:$cooperative_matrix_type
377
- );
378
-
379
- let results = (outs
380
- SPIRV_Int32:$result
381
- );
382
- }
383
-
384
- // -----
385
-
386
- def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad", []> {
387
- let summary = "See extension SPV_NV_cooperative_matrix";
388
-
389
- let description = [{
390
- Load a cooperative matrix through a pointer.
391
-
392
- Result Type is the type of the loaded object. It must be a cooperative
393
- matrix type.
394
-
395
- Pointer is a pointer into an array. Its type must be an OpTypePointer whose
396
- Type operand is a scalar or vector type. The storage class of Pointer must
397
- be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
398
- supported) PhysicalStorageBufferEXT.
399
-
400
- Stride is the number of elements in the array in memory between the first
401
- component of consecutive rows (or columns) in the result. It must be a
402
- scalar integer type.
403
-
404
- ColumnMajor indicates whether the values loaded from memory are arranged in
405
- column-major or row-major order. It must be a boolean constant instruction,
406
- with false indicating row major and true indicating column major.
407
-
408
- Memory Access must be a Memory Access literal. If not present, it is the
409
- same as specifying None.
410
-
411
- If ColumnMajor is false, then elements (row,*) of the result are taken in
412
- order from contiguous locations starting at Pointer[row*Stride]. If
413
- ColumnMajor is true, then elements (*,col) of the result are taken in order
414
- from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
415
- decoration on Pointer is ignored.
416
-
417
- For a given dynamic instance of this instruction, all operands of this
418
- instruction must be the same for all invocations in a given scope instance
419
- (where the scope is the scope the cooperative matrix type was created with).
420
- All invocations in a given scope instance must be active or all must be
421
- inactive.
422
-
423
- ### Custom assembly form
424
-
425
- ``` {.ebnf}
426
- cooperative-matrixload-op ::= ssa-id `=` `spirv.NV.CooperativeMatrixLoad`
427
- ssa-use `,` ssa-use `,` ssa-use
428
- (`[` memory-access `]`)? ` : `
429
- pointer-type `as`
430
- cooperative-matrix-type
431
- ```
432
-
433
- #### Example:
434
-
435
- ```
436
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
437
- : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
438
- ```
439
- }];
440
-
441
- let availability = [
442
- MinVersion<SPIRV_V_1_0>,
443
- MaxVersion<SPIRV_V_1_6>,
444
- Extension<[SPV_NV_cooperative_matrix]>,
445
- Capability<[SPIRV_C_CooperativeMatrixNV]>
446
- ];
447
-
448
- let arguments = (ins
449
- SPIRV_AnyPtr:$pointer,
450
- SPIRV_Integer:$stride,
451
- SPIRV_Bool:$columnmajor,
452
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
453
- );
454
-
455
- let results = (outs
456
- SPIRV_AnyCooperativeMatrixNV:$result
457
- );
458
- }
459
-
460
- // -----
461
-
462
- def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAdd",
463
- [Pure, AllTypesMatch<["c", "result"]>]> {
464
- let summary = "See extension SPV_NV_cooperative_matrix";
465
-
466
- let description = [{
467
- Linear-algebraic matrix multiply of A by B and then component-wise add C.
468
- The order of the operations is implementation-dependent. The internal
469
- precision of floating-point operations is defined by the client API.
470
- Integer operations are performed at the precision of the Result Type and are
471
- exact unless there is overflow or underflow, in which case the result is
472
- undefined.
473
-
474
- Result Type must be a cooperative matrix type with M rows and N columns.
475
-
476
- A is a cooperative matrix with M rows and K columns.
477
-
478
- B is a cooperative matrix with K rows and N columns.
479
-
480
- C is a cooperative matrix with M rows and N columns.
481
-
482
- The values of M, N, and K must be consistent across the result and operands.
483
- This is referred to as an MxNxK matrix multiply.
484
-
485
- A, B, C, and Result Type must have the same scope, and this defines the
486
- scope of the operation. A, B, C, and Result Type need not necessarily have
487
- the same component type, this is defined by the client API.
488
-
489
- If the Component Type of any matrix operand is an integer type, then its
490
- components are treated as signed if its Component Type has Signedness of 1
491
- and are treated as unsigned otherwise.
492
-
493
- For a given dynamic instance of this instruction, all invocations in a given
494
- scope instance must be active or all must be inactive (where the scope is
495
- the scope of the operation).
496
-
497
- #### Example:
498
-
499
- ```
500
- %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, :
501
- !spirv.NV.coopmatrix<8x16xi32, Subgroup>
502
- ```
503
- }];
504
-
505
- let assemblyFormat = [{
506
- operands attr-dict `:` type($a) `,` type($b) `->` type($c)
507
- }];
508
-
509
- let availability = [
510
- MinVersion<SPIRV_V_1_0>,
511
- MaxVersion<SPIRV_V_1_6>,
512
- Extension<[SPV_NV_cooperative_matrix]>,
513
- Capability<[SPIRV_C_CooperativeMatrixNV]>
514
- ];
515
-
516
- let arguments = (ins
517
- SPIRV_AnyCooperativeMatrixNV:$a,
518
- SPIRV_AnyCooperativeMatrixNV:$b,
519
- SPIRV_AnyCooperativeMatrixNV:$c
520
- );
521
-
522
- let results = (outs
523
- SPIRV_AnyCooperativeMatrixNV:$result
524
- );
525
- }
526
-
527
- // -----
528
-
529
- def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore", []> {
530
- let summary = "See extension SPV_NV_cooperative_matrix";
531
-
532
- let description = [{
533
- Store a cooperative matrix through a pointer.
534
-
535
- Pointer is a pointer into an array. Its type must be an OpTypePointer whose
536
- Type operand is a scalar or vector type. The storage class of Pointer must
537
- be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
538
- supported) PhysicalStorageBufferEXT.
539
-
540
- Object is the object to store. Its type must be an
541
- OpTypeCooperativeMatrixNV.
542
-
543
- Stride is the number of elements in the array in memory between the first
544
- component of consecutive rows (or columns) in the result. It must be a
545
- scalar integer type.
546
-
547
- ColumnMajor indicates whether the values stored to memory are arranged in
548
- column-major or row-major order. It must be a boolean constant instruction,
549
- with false indicating row major and true indicating column major.
550
-
551
- Memory Access must be a Memory Access literal. If not present, it is the
552
- same as specifying None.
553
-
554
- ``` {.ebnf}
555
- coop-matrix-store-op ::= `spirv.NV.CooperativeMatrixStore `
556
- ssa-use `, ` ssa-use `, `
557
- ssa-use `, ` ssa-use `, `
558
- (`[` memory-access `]`)? `:`
559
- pointer-type `,` coop-matrix-type
560
- ```
561
-
562
- #### Example:
563
-
564
- ```
565
- spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
566
- !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
567
- ```
568
- }];
569
-
570
- let availability = [
571
- MinVersion<SPIRV_V_1_0>,
572
- MaxVersion<SPIRV_V_1_6>,
573
- Extension<[SPV_NV_cooperative_matrix]>,
574
- Capability<[SPIRV_C_CooperativeMatrixNV]>
575
- ];
576
-
577
- let arguments = (ins
578
- SPIRV_AnyPtr:$pointer,
579
- SPIRV_AnyCooperativeMatrixNV:$object,
580
- SPIRV_Integer:$stride,
581
- SPIRV_Bool:$columnmajor,
582
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
583
- );
584
-
585
- let results = (outs);
586
- }
587
-
588
341
// -----
589
342
590
343
#endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
0 commit comments