Skip to content

Commit

Permalink
[MFMA] MFMA 4x64 64x4 version 2
Browse files Browse the repository at this point in the history
Extend K dimension of mfma4x64 and mfma64x4 dot operand layout from 4 to 64.
  • Loading branch information
binarman committed Mar 19, 2024
1 parent da5040d commit fde46d8
Show file tree
Hide file tree
Showing 12 changed files with 667 additions and 344 deletions.
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ struct MfmaInsnAttr {
unsigned n;
unsigned k;
// k_base_a / k_base_b refer to the number of elements per thread for operands A and B
unsigned k_base;
unsigned k_base_a;
unsigned k_base_b;
llvm::StringRef insn;
};

Expand Down Expand Up @@ -223,7 +224,8 @@ class MfmaInsn {
unsigned getMDim();
unsigned getNDim();
StringRef getInsnName();
unsigned getKBase();
unsigned getKBaseA();
unsigned getKBaseB();
};
} // namespace mlir

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 ||
(mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) &&
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,12 @@ llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else {
// In this configuration wave contains 16 copies of same data
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
// shortcut for 64x64 tile size.
// In this case warp do not wrap, so no need to introduce this offset
if (iNonKDim == 64)
laneHOffset = i32_val(0);
} else {
assert(iKDim * iNonKDim / numOfElems == 64 &&
"seems no all threads in wave contain unique elements");
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
}
}

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Expand Down Expand Up @@ -346,7 +344,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfOffset;
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
if (iNonKDim == 64)
halfOffset = i32_val(0);
else
halfOffset =
Expand Down Expand Up @@ -456,6 +454,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int numSubBlocks = 1;
if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
numSubBlocks = 16;
assert(numSubBlocks == 1 &&
"after reworking layout, there should be no redundency");
int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
assert(numOfElems >= 1);

Expand Down
234 changes: 174 additions & 60 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;

using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
// mapping from tuple <kpack-rep, non-k-rep, k-rep> to vector of values
// vector contains single element for MFMA32, MFMA16 and MFMA4 layouts
// for MFMA 4x64 and 64x4 layouts there are 16 vectors for one of the arguments,
// because each repetition in these layouts requires 16 mfma operations
using ValueTable = std::map<std::tuple<unsigned, unsigned, unsigned>,
llvm::SmallVector<Value>>;

struct DotOpMFMAConversionHelper {
MfmaEncodingAttr mfmaLayout;
Expand All @@ -60,16 +65,114 @@ struct DotOpMFMAConversionHelper {
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}

/**
* @param mfmaInsnName
* @param valA
* @param valB
* @param valC
* @param cbsz Control Broadcast Size modifier
* @param abid A-matrix Broadcast Identifier
* @param blgp B-matrix Lane Group Pattern modifier
*/
Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
Value valC) const {
Value valC, int cbsz = 0, int abid = 0,
int blgp = 0) const {
assert(cbsz >= 0 && cbsz <= 4);
assert(abid >= 0 && abid <= 15);
assert(blgp >= 0 && blgp <= 7);
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
Value zeroVal = i32_val(0);
Value cbszFlag = cbsz != 0 ? i32_val(cbsz) : zeroVal;
Value abidFlag = abid != 0 ? i32_val(abid) : zeroVal;
Value blgpFlag = blgp != 0 ? i32_val(blgp) : zeroVal;
OperationState loweredOp(loc, mfmaInsnName);
loweredOp.addTypes(resType);
loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
loweredOp.addOperands({valA, valB, valC, cbszFlag, abidFlag, blgpFlag});
return rewriter.create(loweredOp)->getResult(0);
}

/**
 * @brief Replicate the values held by one group of lanes to every group.
 *
 * The 64-lane wave is split into `numGroups` equally sized groups. Each lane
 * reads `val` from the lane with the same in-group position inside group
 * `groupId`, using one ds_bpermute per 32-bit chunk of the value.
 *
 * Supported value types: i32, f32 and vectors whose total bit width is a
 * non-zero multiple of 32. Any other type trips the final assertion.
 *
 * @param val value to broadcast
 * @param groupId index of the source group, in [0, numGroups)
 * @param numGroups number of groups in the wave; must evenly divide 64
 *
 * @return value of the same type as `val`, fetched from the source group
 */
Value broadcastGroup(Value val, int groupId, int numGroups) const {
  constexpr int waveSize = 64;
  // Guard the division below and keep the permute address inside the wave.
  assert(numGroups > 0 && waveSize % numGroups == 0 &&
         "wave must split evenly into groups");
  assert(groupId >= 0 && groupId < numGroups && "group id is out of range");
  const int groupSize = waveSize / numGroups;

  Value lane = getThreadId();
  // Multiply by 4, because permute requires offset in bytes
  Value laneOffset = mul(urem(lane, i32_val(groupSize)), i32_val(4));
  Value permuteAddr = add(laneOffset, i32_val(groupId * groupSize * 4));
  Type valType = val.getType();
  Value broadcasted;
  if (valType.isInteger(32)) {
    broadcasted = rewriter.create<ROCDL::DsBpermuteOp>(loc, val.getType(),
                                                       permuteAddr, val);
  } else if (valType.isF32()) {
    // The permute is emitted on i32 here, so round-trip f32 through a bitcast.
    val = bitcast(val, i32_ty);
    broadcasted = rewriter.create<ROCDL::DsBpermuteOp>(loc, val.getType(),
                                                       permuteAddr, val);
    broadcasted = bitcast(broadcasted, f32_ty);
  } else if (auto vecTy = valType.dyn_cast<VectorType>()) {
    auto vecBitSize = vecTy.getElementType().getIntOrFloatBitWidth() *
                      vecTy.getNumElements();
    // The vector is reinterpreted as i32 chunks; a narrower or unaligned
    // vector would otherwise produce an ill-formed (possibly empty) i32 vector.
    assert(vecBitSize > 0 && vecBitSize % 32 == 0 &&
           "vector bit width must be a non-zero multiple of 32");
    const int int32VecSize = vecBitSize / 32;

    Type int32VecTy = vec_ty(i32_ty, int32VecSize);
    Value int32Val = bitcast(val, int32VecTy);
    Value int32Broadcasted = undef(int32VecTy);
    // Permute each 32-bit chunk separately and reassemble the vector.
    for (int i = 0; i < int32VecSize; ++i) {
      Value int32Chunk = extract_element(i32_ty, int32Val, i32_val(i));
      Value broadcastedChunk = rewriter.create<ROCDL::DsBpermuteOp>(
          loc, i32_ty, permuteAddr, int32Chunk);
      int32Broadcasted = insert_element(int32VecTy, int32Broadcasted,
                                        broadcastedChunk, i32_val(i));
    }
    broadcasted = bitcast(int32Broadcasted, valType);
  }
  assert(broadcasted);
  return broadcasted;
}

/**
 * @brief Emit the mfma instruction(s) that accumulate one tile into valC.
 *
 * Square layouts (mDim == nDim) need a single mfma operation. The 4x64 and
 * 64x4 layouts need 16 mfma operations per tile: the "wide" operand supplies
 * 16 values (one per operation) while the other operand supplies one value
 * that is shared by all 16 operations, either through the cbsz/abid
 * broadcast modifiers or through an explicit broadcastGroup().
 *
 * @param mfmaInsnName name of the mfma operation to create
 * @param valA operand A values: 1 element, or 16 for the 64x4 layout
 * @param valB operand B values: 1 element, or 16 for the 4x64 layout
 * @param valC incoming accumulator
 * @param mDim M size of the mfma layout
 * @param nDim N size of the mfma layout
 * @param transpose when true, A and B swap positions in the emitted operation
 *
 * @return accumulator after all operations of the tile
 */
Value generateMFMATile(StringRef mfmaInsnName, SmallVector<Value> valA,
                       SmallVector<Value> valB, Value valC, int mDim,
                       int nDim, bool transpose) const {

  Value acc;
  if (mDim == nDim) {
    assert(valA.size() == 1 && valB.size() == 1);
    acc = transpose ? generateMFMAOp(mfmaInsnName, valB[0], valA[0], valC)
                    : generateMFMAOp(mfmaInsnName, valA[0], valB[0], valC);
  } else if ((mDim == 4 && nDim == 64) || (mDim == 64 && nDim == 4)) {
    // broadcast selected kRep A operand matrix to all A matrices(2^4=16)
    constexpr int broadcastCtrl = 4;
    constexpr int numRepeats = 16;
    // In the 4x64 layout B is the wide operand, in the 64x4 layout A is.
    const bool wideOperandIsB = (mDim == 4);
    if (wideOperandIsB)
      assert(valA.size() == 1 && valB.size() == 16);
    else
      assert(valA.size() == 16 && valB.size() == 1);

    acc = valC;
    for (int kRep = 0; kRep < numRepeats; kRep++) {
      if (wideOperandIsB) {
        if (!transpose) {
          // Narrow operand is in the first mfma slot: use cbsz/abid broadcast.
          acc = generateMFMAOp(mfmaInsnName, valA[0], valB[kRep], acc,
                               broadcastCtrl, kRep);
        } else {
          // Narrow operand is in the second slot: broadcast it explicitly.
          Value broadcastValA = broadcastGroup(valA[0], kRep, numRepeats);
          acc = generateMFMAOp(mfmaInsnName, valB[kRep], broadcastValA, acc);
        }
      } else {
        if (!transpose) {
          // Narrow operand is in the second slot: broadcast it explicitly.
          Value broadcastValB = broadcastGroup(valB[0], kRep, numRepeats);
          acc = generateMFMAOp(mfmaInsnName, valA[kRep], broadcastValB, acc);
        } else {
          // Narrow operand is in the first mfma slot: use cbsz/abid broadcast.
          acc = generateMFMAOp(mfmaInsnName, valB[0], valA[kRep], acc,
                               broadcastCtrl, kRep);
        }
      }
    }
  }
  // Any other (mDim, nDim) pair is not handled above; fail loudly instead of
  // handing a null accumulator back to the caller.
  assert(acc && "unsupported MFMA tile shape");
  return acc;
}

int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64)
return 1;
Expand Down Expand Up @@ -187,13 +290,14 @@ struct DotOpMFMAConversionHelper {
llvm::report_fatal_error("No match found in MFMA database\n");

mfmaInsnName = (*maybeMfmaInsn).getInsnName();
unsigned k_base = (*maybeMfmaInsn).getKBase();
unsigned kBaseA = (*maybeMfmaInsn).getKBaseA();
unsigned kBaseB = (*maybeMfmaInsn).getKBaseB();

auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();

auto kWidth = aEncoding.getKWidth();
assert(kWidth == bEncoding.getKWidth());
auto kWidthA = aEncoding.getKWidth();
auto kWidthB = bEncoding.getKWidth();

auto repA = aEncoding.getMFMARep(aTensorTy.getShape());
auto repB = bEncoding.getMFMARep(bTensorTy.getShape());
Expand All @@ -209,9 +313,9 @@ struct DotOpMFMAConversionHelper {
auto numRepK = repA[1];

auto operandA = getValuesFromDotOperandLayoutStruct(
loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType());
loadedA, numRepM, numRepK, kWidthA, kBaseA, aTensorTy.getElementType());
auto operandB = getValuesFromDotOperandLayoutStruct(
loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType());
loadedB, numRepN, numRepK, kWidthB, kBaseB, aTensorTy.getElementType());

auto dstElemTy = dTensorTy.getElementType();
auto fc =
Expand All @@ -236,12 +340,10 @@ struct DotOpMFMAConversionHelper {

acc = zeroAuxiliarBlocks(subBlocks, acc);
for (size_t k = 0; k < numRepK; k++)
for (int kpack = 0; kpack < kWidth / k_base; ++kpack)
acc = mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}],
operandA[kpack][{m, k}], acc)
: generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}],
operandB[kpack][{n, k}], acc);
for (int kpack = 0; kpack < kWidthA / kBaseA; ++kpack)
acc = generateMFMATile(mfmaInsnName, operandA[{kpack, m, k}],
operandB[{kpack, n, k}], acc, mDim, nDim,
mfmaLayout.getIsTransposed());
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {
fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
Expand All @@ -260,30 +362,39 @@ struct DotOpMFMAConversionHelper {
}

/**
* @brief extract vector from rawElems based on kWidth and k_base
* @brief extract vector from rawElems based on kWidth and kBase
* rawElems is a vector of kWidth elements. We need to prepare vector(s) of
* k_base elements for each mfma instruction
* kBase elements for each mfma instruction
*
* @param rawElems vector of "raw" elements for one mfma tile
* @param k id in k-pack
* @param kPack size of k-pack
* @param numIntrinsics number of operands we need to extract
* @param type type mfma intrinsic requires
*
* @return elements converted for one repetition
*/
SmallVector<Value> extractOperands(Value rawElems, int kWidth, int k_base,
Type type) const {
int kpack = kWidth / k_base;
SmallVector<Value> extractOperands(Value rawElems, int k, int kPack,
int numIntrinsics, Type type) const {
assert(numIntrinsics == 1 || numIntrinsics == 16);
auto rawTy = rawElems.getType().cast<VectorType>();
auto rawElemTy = rawTy.getElementType();
// number of elements required by one mfma intrinsic
int intrinsicK = rawTy.getNumElements() / numIntrinsics / kPack;
int kBase = rawTy.getNumElements() / kPack;

SmallVector<Value> results;
auto vecTy = vec_ty(type, k_base);
for (int k = 0; k < kpack; ++k) {
Value vec = undef(vecTy);
for (int elemId = 0; elemId < k_base; ++elemId) {
auto val =
extract_element(type, rawElems, i32_val(elemId + k * k_base));
vec = insert_element(vecTy, vec, val, i32_val(elemId));
// extract needed elements in original dtype
auto typedVecTy = vec_ty(rawElemTy, intrinsicK);
for (int intrinsic = 0; intrinsic < numIntrinsics; ++intrinsic) {
Value typedVec = undef(typedVecTy);
for (int elemId = 0; elemId < intrinsicK; ++elemId) {
int elemOff = elemId + intrinsic * intrinsicK + k * kBase;
auto val = extract_element(rawElemTy, rawElems, i32_val(elemOff));
typedVec = insert_element(typedVecTy, typedVec, val, i32_val(elemId));
}
if (type.getIntOrFloatBitWidth() == 8) {
if (4 == k_base)
// This is for int8 on pre- MI300 GPUs
results.push_back(bitcast(vec, i32_ty));
if (8 == k_base)
results.push_back(bitcast(vec, i64_ty));
} else
results.push_back(vec);
Value castedVec = bitcast(typedVec, type);
results.push_back(castedVec);
}
return results;
}
Expand All @@ -292,35 +403,38 @@ struct DotOpMFMAConversionHelper {
* @brief Converts dot operand structure to value table and converts types
* appropriate for mfma instructions
*/
SmallVector<ValueTable>
getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth,
int k_base, Type type) const {
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1,
int kWidth, int kBase,
Type type) const {
auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);
ValueTable vals;
ValueTable vals1;
int kpack = kWidth / k_base;
SmallVector<ValueTable> dotOpVals(kpack);
int kpack = kWidth / kBase;
// "Wide operand" means that this operand is for mfma 4x64 layout
// This operand is 64x64 for fp16, bf16 and int8 data types and
// 16x64 for fp32
bool wideOperand = kWidth >= 16;
// How many rocdl intrinsics will process one tile
int numIntrinsics = wideOperand ? 16 : 1;
int intrinsicKWidth = wideOperand ? kBase / numIntrinsics : kBase;
Type intrinsicDType;
if (type.isF32())
intrinsicDType = f32_ty;
if (type.getIntOrFloatBitWidth() == 8)
intrinsicDType = rewriter.getIntegerType(intrinsicKWidth * 8);
if (type.isBF16())
intrinsicDType = vec_ty(i16_ty, intrinsicKWidth);
if (type.isF16())
intrinsicDType = vec_ty(f16_ty, intrinsicKWidth);
assert(intrinsicDType);

ValueTable dotOpVals;
for (int i = 0; i < n0; i++) {
for (int j = 0; j < n1; j++) {
auto rawElems = elems[n1 * i + j];

if (type.isF32()) {
for (int k = 0; k < kpack; ++k) {
dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k));
}
} else {
SmallVector<Value> vals;
if (type.getIntOrFloatBitWidth() == 8) {
vals = extractOperands(rawElems, kWidth, k_base, i8_ty);
} else if (type.isBF16()) {
vals = extractOperands(rawElems, kWidth, k_base, i16_ty);
} else {
assert(type.isF16() && "Unsupported data type");
vals = extractOperands(rawElems, kWidth, k_base, f16_ty);
}
for (int k = 0; k < kpack; ++k) {
dotOpVals[k][{i, j}] = vals[k];
}
for (int k = 0; k < kpack; k++) {
SmallVector<Value> vals = extractOperands(
rawElems, k, kpack, numIntrinsics, intrinsicDType);
assert(vals.size() == numIntrinsics);
dotOpVals[{k, i, j}] = vals;
}
}
}
Expand Down
Loading

0 comments on commit fde46d8

Please # to comment.