[MFMA][FRONTEND] Add more options for forced mfma layout sizes
This PR:
- adds a `matrix_instr_nonkdim` option value to force the MFMA 4x64 and 64x4 layouts: 464 corresponds to 4(M)x64(N), 644 corresponds to 64(M)x4(N) (see the usage sketch below)
- adds tests for this option
- fixes the swizzling pattern in some cases
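
For context, a hedged usage sketch of the option, assuming `matrix_instr_nonkdim` is accepted as a launch-time keyword the way `num_warps` is; the kernel body and shapes are illustrative (they mirror the new 16x64x128 test case), not taken from this diff:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dot_kernel(a_ptr, b_ptr, c_ptr,
                   M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
        offs_m = tl.arange(0, M)
        offs_n = tl.arange(0, N)
        offs_k = tl.arange(0, K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], tl.dot(a, b))

    a = torch.randn(16, 128, device='cuda', dtype=torch.float16)
    b = torch.randn(128, 64, device='cuda', dtype=torch.float16)
    c = torch.empty(16, 64, device='cuda', dtype=torch.float32)
    # 464 requests the 4(M)x64(N) MFMA layout; 644 would request 64(M)x4(N).
    dot_kernel[(1,)](a, b, c, M=16, N=64, K=128, matrix_instr_nonkdim=464)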

The MFMA size heuristic now works as follows (a Python sketch follows the list):

1. If the kernel-specific option is set, pick it
2. If the result tile shape is at least 32x32, pick mfma32
3. If the tile shape is smaller than 32x32 but at least 16x16, pick mfma16
4. If the tile shape is smaller than both 4x64 and 64x4, pick mfma4x4
5. Otherwise, pick mfma4x64 or mfma64x4, depending on which of the two the tile fits into
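
A minimal Python sketch of the selection logic above; the real implementation is the C++ in AccelerateAMDMatmul.cpp below, and the exact tie-breaking in steps 4-5 is my reading of the diff:

    def pick_mfma_dims(res_m, res_n, enforced_non_k_dim=0):
        # Step 1: kernel-specific override, using the 464/644 encoding above.
        if enforced_non_k_dim:
            encodings = {32: (32, 32), 16: (16, 16), 4: (4, 4),
                         464: (4, 64), 644: (64, 4)}
            return encodings[enforced_non_k_dim]
        min_size = min(res_m, res_n)
        if min_size >= 32:   # step 2
            return (32, 32)
        if min_size >= 16:   # step 3
            return (16, 16)
        # Steps 4-5: one result dim is below 16; pick the wide/tall variant
        # only when the other dim actually spans 64 elements.
        if res_m < 16 and res_n >= 64:
            return (4, 64)
        if res_m >= 64 and res_n < 16:
            return (64, 4)
        return (4, 4)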
binarman committed Mar 19, 2024
1 parent b395044 commit da5040d
Showing 3 changed files with 41 additions and 3 deletions.
10 changes: 10 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false.

if (mfmaEnc) {
int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
int nonKDimNum = 1 - kDimNum;
if (needTrans)
kDimNum = 1 - kDimNum;
bool isKDimInner = (order[0] == kDimNum);
@@ -154,6 +155,15 @@ compared to 1*64 when the hasLeadingOffset is false.
auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if (4 == nonKDim)
maxPhase = 4;
// if maxPhase * perPhase is larger than one block of warps,
// fall back to an unswizzled tensor.
// Shared to dot op conversion requires that the swizzling pattern
// fits into one block of warps.
auto warpsPerCTA = mfmaEnc.getWarpsPerCTA();
if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) {
assert(isKDimInner);
maxPhase = 1;
}
assert(maxPhase > 0);

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
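
To make the new fallback concrete, a toy recheck of the condition with assumed values; the function and parameter names are descriptive, not from the repo:

    def falls_back_to_unswizzled(max_phase, per_phase, non_k_dim, warps_along_non_k):
        # The swizzling pattern must fit into one block of warps, or the
        # shared-to-dot-op conversion cannot reproduce it (see comment above).
        return max_phase * per_phase > non_k_dim * warps_along_non_k

    # mfma4 operand (nonKDim = 4) with one warp along the non-K dim:
    # maxPhase was just forced to 4, so perPhase = 2 overflows 4 * 1 = 4.
    print(falls_back_to_unswizzled(4, 2, 4, 1))  # True -> maxPhase reset to 1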
16 changes: 14 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
@@ -175,8 +175,20 @@ class BlockedToMFMA : public mlir::RewritePattern {
unsigned mDim = 0;
unsigned nDim = 0;
if (enforcedNonKDim != 0) {
mDim = enforcedNonKDim;
nDim = enforcedNonKDim;
if (enforcedNonKDim == 32 || enforcedNonKDim == 16 ||
enforcedNonKDim == 4) {
mDim = enforcedNonKDim;
nDim = enforcedNonKDim;
} else if (enforcedNonKDim == 464) {
mDim = 4;
nDim = 64;
} else if (enforcedNonKDim == 644) {
mDim = 64;
nDim = 4;
} else {
llvm::report_fatal_error("Invalid MFMA nonKDim option, supported "
"values are: 32, 16, 4, 464, 644");
}
} else {
int minSize = std::min(resShape[0], resShape[1]);
if (minSize >= 32) {
18 changes: 17 additions & 1 deletion python/test/unit/language/test_core_amd.py
@@ -1665,6 +1665,19 @@ def kernel(X, stride_xm, stride_xn,
for non_k_dim in [0, 4, 16, 32]
if not (allow_tf32 and (in_dtype in ['float16']))] +
[(*shape, warps, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1)
for shape in [(64, 16, 128), (16, 64, 128)]
for warps in [1, 4]
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
for allow_tf32 in [False]
for in_dtype, out_dtype in [('float16', 'float16'),
('bfloat16', 'float32'),
('float8e5m2fnuz', 'float32'),
('float8e4m3fnuz', 'float32'),
('float16', 'float32'),
('float32', 'float32')]
for non_k_dim in [464, 644]] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim, kpack)
for shape_nw in [[128, 128, 32, 2],
[128, 16, 32, 4],
@@ -1728,6 +1741,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
pytest.skip("incompatible non_k_dim == 4 with K size")
if non_k_dim == 4 and (M > 16 or N > 16):
pytest.skip("skipping large matrices for non_k_dim == 4 to speedup testing")
if (non_k_dim == 464 and N < 64) or (non_k_dim == 644 and M < 64):
pytest.skip(f"skipping non_k_dim={non_k_dim} specific test with incompatible matrix sizes")


if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -1852,7 +1868,7 @@ def kernel(X, stride_xm, stride_xk,

z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
z_tri = torch.as_strided(z_tri, (M, N), [1, M])

if out_dtype == 'int8':
out_dtype = tl.int8
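An aside on the last hunk: as far as I can tell, z_tri.stride()[::-1] equals (1, N) for a contiguous (M, N) tensor and only coincides with the correct transposed strides (1, M) when M == N, which is why the non-square shapes added in this PR need the explicit [1, M]. A standalone illustration with assumed sizes:

    import torch

    M, N = 4, 8
    z = torch.arange(M * N, dtype=torch.float32).reshape(M, N)
    # Strides [1, M] make the (M, N) view walk the same storage column-major,
    # i.e. the transpose of the buffer read as an (N, M) row-major tensor.
    z_t = torch.as_strided(z, (M, N), [1, M])
    assert torch.equal(z_t, z.reshape(N, M).t())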
