diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index c8fdc1c70f5c..54bc607a4ede 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false.
     if (mfmaEnc) {
       int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
+      int nonKDimNum = 1 - kDimNum;
       if (needTrans)
         kDimNum = 1 - kDimNum;
       bool isKDimInner = (order[0] == kDimNum);
@@ -154,6 +155,15 @@ compared to 1*64 when the hasLeadingOffset is false.
       auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
       if (4 == nonKDim)
         maxPhase = 4;
+      // If maxPhase * perPhase is larger than one block of warps,
+      // fall back to an unswizzled tensor.
+      // Shared to dot op conversion requires that the swizzling pattern
+      // fits into one block of warps.
+      auto warpsPerCTA = mfmaEnc.getWarpsPerCTA();
+      if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) {
+        assert(isKDimInner);
+        maxPhase = 1;
+      }
       assert(maxPhase > 0);

       return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
index 12fdbf23e4a4..3f39248597bd 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
@@ -175,8 +175,20 @@ class BlockedToMFMA : public mlir::RewritePattern {
     unsigned mDim = 0;
     unsigned nDim = 0;
     if (enforcedNonKDim != 0) {
-      mDim = enforcedNonKDim;
-      nDim = enforcedNonKDim;
+      if (enforcedNonKDim == 32 || enforcedNonKDim == 16 ||
+          enforcedNonKDim == 4) {
+        mDim = enforcedNonKDim;
+        nDim = enforcedNonKDim;
+      } else if (enforcedNonKDim == 464) {
+        mDim = 4;
+        nDim = 64;
+      } else if (enforcedNonKDim == 644) {
+        mDim = 64;
+        nDim = 4;
+      } else {
+        llvm::report_fatal_error("Invalid MFMA nonKDim option, supported "
+                                 "values are: 32, 16, 4, 464, 644");
+      }
     } else {
       int minSize = std::min(resShape[0], resShape[1]);
       if (minSize >= 32) {
diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py
index 2bf5c63dd613..0a451d539453 100644
--- a/python/test/unit/language/test_core_amd.py
+++ b/python/test/unit/language/test_core_amd.py
@@ -1665,6 +1665,19 @@ def kernel(X, stride_xm, stride_xn,
                           for non_k_dim in [0, 4, 16, 32]
                           if not (allow_tf32 and (in_dtype in ['float16']))] +
+                         [(*shape, warps, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1)
+                          for shape in [(64, 16, 128), (16, 64, 128)]
+                          for warps in [1, 4]
+                          for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
+                          for allow_tf32 in [False]
+                          for in_dtype, out_dtype in [('float16', 'float16'),
+                                                      ('bfloat16', 'float32'),
+                                                      ('float8e5m2fnuz', 'float32'),
+                                                      ('float8e4m3fnuz', 'float32'),
+                                                      ('float16', 'float32'),
+                                                      ('float32', 'float32')]
+                          for non_k_dim in [464, 644]] +
+
                          [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim, kpack)
                           for shape_nw in [[128, 128, 32, 2],
                                            [128, 16, 32, 4],
@@ -1728,6 +1741,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
         pytest.skip("incompatible non_k_dim == 4 with K size")
     if non_k_dim == 4 and (M > 16 or N > 16):
         pytest.skip("skipping large matrices for non_k_dim == 4 to speedup testing")
+    if (non_k_dim == 464 and N < 64) or (non_k_dim == 644 and M < 64):
+        pytest.skip(f"skipping non_k_dim={non_k_dim} specific test with incompatible matrix sizes")
+
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -1852,7 +1868,7 @@ def kernel(X, stride_xm, stride_xk,
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
-        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
+        z_tri = torch.as_strided(z_tri, (M, N), [1, M])
     if out_dtype == 'int8':
         out_dtype = tl.int8
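
Note for reviewers: `464` and `644` are mnemonic encodings of the rectangular instruction shapes, selecting mfma 4x64 (mDim=4, nDim=64) and mfma 64x4 (mDim=64, nDim=4) respectively, which is why the new skip condition requires N >= 64 for 464 and M >= 64 for 644. The sketch below shows how a launch might force one of these layouts. It assumes the `matrix_instr_nonkdim` launch option is the path that feeds `enforcedNonKDim` (the route the `non_k_dim` test parameter takes in this suite); the single-program kernel and the concrete shapes are illustrative only, not part of this patch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(X, Y, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # One program computes the whole M x N tile from row-major inputs.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    x = tl.load(X + offs_m[:, None] * K + offs_k[None, :])
    y = tl.load(Y + offs_k[:, None] * N + offs_n[None, :])
    # tl.dot accumulates in fp32 and lowers to MFMA ops on AMD hardware.
    tl.store(Z + offs_m[:, None] * N + offs_n[None, :], tl.dot(x, y))


x = torch.randn((16, 128), dtype=torch.float16, device='cuda')
y = torch.randn((128, 64), dtype=torch.float16, device='cuda')
z = torch.empty((16, 64), dtype=torch.float32, device='cuda')

# Request the 4x64 MFMA layout; mirrors the new (16, 64, 128) test shape.
# 644 would instead require the (64, 16, 128) orientation.
dot_kernel[(1,)](x, y, z, M=16, N=64, K=128, matrix_instr_nonkdim=464)
```

The two tall-and-skinny shapes added to the test parametrization cover both orientations of the rectangular instruction, so each of the new `mDim`/`nDim` branches in `BlockedToMFMA` is exercised.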