From da5040d5a5010becc7a83ea6466dd2913e17beef Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 25 Jan 2024 13:20:40 +0000 Subject: [PATCH] [MFMA][FRONTEND] Add more options for forced mfma layout sizes This PR: - adds an `matrix_instr_nonkdim` options to force MFMA 64x4 and 4x64 layout: 464 corresponds 4(M)x64(N), 644 corresponds 64(M)x4(N) - adds tests for this option - fixes swizzling patter in some cases MFMA size heuristic now looks like this: 1. If kernel specific option is set, pick it 2. If the result tile shape is larger than 32x32, pick mfma32 3. If the tile shape is smaller than 32x32 but larger than 16x16, pick mfma16 4. if the tile shape is smaller than 4x64 or 64x4, pick mfma4x4 5. Otherwise, pick mfma4x64 or mfma64x4, depending on what tile fits into matrices --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 10 ++++++++++ .../Transforms/AccelerateAMDMatmul.cpp | 16 ++++++++++++++-- python/test/unit/language/test_core_amd.py | 18 +++++++++++++++++- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c8fdc1c70f5c..54bc607a4ede 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false. if (mfmaEnc) { int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + int nonKDimNum = 1 - kDimNum; if (needTrans) kDimNum = 1 - kDimNum; bool isKDimInner = (order[0] == kDimNum); @@ -154,6 +155,15 @@ compared to 1*64 when the hasLeadingOffset is false. auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim; if (4 == nonKDim) maxPhase = 4; + // if maxPhase * perPhase is larger than one block of warps, + // fallback to unswizzled tensor. + // Shared to dot op conversion requires that swizzling patern + // fits into one block of warps. + auto warpsPerCTA = mfmaEnc.getWarpsPerCTA(); + if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) { + assert(isKDimInner); + maxPhase = 1; + } assert(maxPhase > 0); return get(context, vecSize, perPhase, maxPhase, order, CTALayout); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 12fdbf23e4a4..3f39248597bd 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -175,8 +175,20 @@ class BlockedToMFMA : public mlir::RewritePattern { unsigned mDim = 0; unsigned nDim = 0; if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; + if (enforcedNonKDim == 32 || enforcedNonKDim == 16 || + enforcedNonKDim == 4) { + mDim = enforcedNonKDim; + nDim = enforcedNonKDim; + } else if (enforcedNonKDim == 464) { + mDim = 4; + nDim = 64; + } else if (enforcedNonKDim == 644) { + mDim = 64; + nDim = 4; + } else { + llvm::report_fatal_error("Invalid MFMA nonKDim option, supported " + "values are: 32, 16, 4, 464, 644"); + } } else { int minSize = std::min(resShape[0], resShape[1]); if (minSize >= 32) { diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 2bf5c63dd613..0a451d539453 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1665,6 +1665,19 @@ def kernel(X, stride_xm, stride_xn, for non_k_dim in [0, 4, 16, 32] if not (allow_tf32 and (in_dtype in ['float16']))] + + [(*shape, warps, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1) + for shape in [(64, 16, 128), (16, 64, 128)] + for warps in [1, 4] + for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax'] + for allow_tf32 in [False] + for in_dtype, out_dtype in [('float16', 'float16'), + ('bfloat16', 'float32'), + ('float8e5m2fnuz', 'float32'), + ('float8e4m3fnuz', 'float32'), + ('float16', 'float32'), + ('float32', 'float32')] + for non_k_dim in [464, 644]] + + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim, kpack) for shape_nw in [[128, 128, 32, 2], [128, 16, 32, 4], @@ -1728,6 +1741,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o pytest.skip("incompatible non_k_dim == 4 with K size") if non_k_dim == 4 and (M > 16 or N > 16): pytest.skip("skipping large matrices for non_k_dim == 4 to speedup testing") + if (non_k_dim == 464 and N < 64) or (non_k_dim == 644 and M < 64): + pytest.skip(f"skipping non_k_dim={non_k_dim} specific test with incompatible matrix sizes") + if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -1852,7 +1868,7 @@ def kernel(X, stride_xm, stride_xk, z_tri = to_triton(z, device=device) if epilogue == 'trans': - z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) if out_dtype == 'int8': out_dtype = tl.int8