Skip to content

Commit cbafb6f

Browse files
authored
[NVPTX] Improve lowering of v4i8 (#67866)
Make v4i8 a legal type and plumb through lowering of relevant instructions.
1 parent 67b675e commit cbafb6f

15 files changed

+1897
-540
lines changed

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
309309
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
310310
O << Sym.getName();
311311
}
312+
313+
// Prints the mode suffix selected by the immediate operand at index OpNum of
// MI to O. PTXPrmtMode::NONE prints nothing; every other known mode prints its
// dotted lowercase suffix (".f4e", ".b4e", ".rc8", ".ecl", ".ecr", ".rc16").
// Modifier is accepted but not used by this printer.
void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
314+
raw_ostream &O, const char *Modifier) {
315+
const MCOperand &MO = MI->getOperand(OpNum);
316+
int64_t Imm = MO.getImm();
317+
318+
switch (Imm) {
// An immediate that matches no PTXPrmtMode enumerator is silently ignored:
// the function returns without writing anything to O.
319+
default:
320+
return;
// NONE is a valid mode that simply has no textual suffix.
321+
case NVPTX::PTXPrmtMode::NONE:
322+
break;
323+
case NVPTX::PTXPrmtMode::F4E:
324+
O << ".f4e";
325+
break;
326+
case NVPTX::PTXPrmtMode::B4E:
327+
O << ".b4e";
328+
break;
329+
case NVPTX::PTXPrmtMode::RC8:
330+
O << ".rc8";
331+
break;
332+
case NVPTX::PTXPrmtMode::ECL:
333+
O << ".ecl";
334+
break;
335+
case NVPTX::PTXPrmtMode::ECR:
336+
O << ".ecr";
337+
break;
338+
case NVPTX::PTXPrmtMode::RC16:
339+
O << ".rc16";
340+
break;
341+
}
342+
}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
4747
raw_ostream &O, const char *Modifier = nullptr);
4848
void printProtoIdent(const MCInst *MI, int OpNum,
4949
raw_ostream &O, const char *Modifier = nullptr);
50+
// Operand printer for the mode immediate at operand index OpNum of MI; writes
// the mode's textual suffix (if any) to O. Modifier defaults to nullptr.
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
51+
const char *Modifier = nullptr);
5052
};
5153

5254
}

llvm/lib/Target/NVPTX/NVPTX.h

+12
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,18 @@ enum CmpMode {
181181
FTZ_FLAG = 0x100
182182
};
183183
}
184+
185+
// Mode selector values consumed by NVPTXInstPrinter::printPrmtMode.
// Enumerators take consecutive values starting at NONE = 0. NONE prints no
// suffix; each other enumerator prints as its lowercase name with a leading
// dot (e.g. F4E -> ".f4e", RC16 -> ".rc16").
// NOTE(review): these presumably mirror the modes of the PTX `prmt`
// instruction — confirm against the PTX ISA documentation.
namespace PTXPrmtMode {
186+
enum PrmtMode {
187+
NONE,
188+
F4E,
189+
B4E,
190+
RC8,
191+
ECL,
192+
ECR,
193+
RC16,
194+
};
195+
}
184196
}
185197
void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
186198
} // namespace llvm

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

+14-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "MCTargetDesc/NVPTXBaseInfo.h"
1515
#include "NVPTXUtilities.h"
1616
#include "llvm/Analysis/ValueTracking.h"
17+
#include "llvm/CodeGen/ISDOpcodes.h"
1718
#include "llvm/IR/GlobalValue.h"
1819
#include "llvm/IR/Instructions.h"
1920
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
829830
case MVT::v2f16:
830831
case MVT::v2bf16:
831832
case MVT::v2i16:
833+
case MVT::v4i8:
832834
return Opcode_i32;
833835
case MVT::f32:
834836
return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
910912
// Vector Setting
911913
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
912914
if (SimpleVT.isVector()) {
913-
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
915+
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
916+
"Unexpected vector type");
914917
// v2f16/v2bf16/v2i16/v4i8 are loaded using ld.b32
915918
fromTypeWidth = 32;
916919
}
@@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12541257
SDLoc DL(N);
12551258
SDNode *LD;
12561259
SDValue Base, Offset, Addr;
1260+
EVT OrigType = N->getValueType(0);
12571261

12581262
EVT EltVT = Mem->getMemoryVT();
12591263
unsigned NumElts = 1;
12601264
if (EltVT.isVector()) {
12611265
NumElts = EltVT.getVectorNumElements();
12621266
EltVT = EltVT.getVectorElementType();
12631267
// Vectors of 16-bit types are loaded/stored as multiples of v2x16 elements.
1264-
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
1265-
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
1266-
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
1268+
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
1269+
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
1270+
(EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
12671271
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1268-
EltVT = N->getValueType(0);
1272+
EltVT = OrigType;
12691273
NumElts /= 2;
1274+
} else if (OrigType == MVT::v4i8) {
1275+
EltVT = OrigType;
1276+
NumElts = 1;
12701277
}
12711278
}
12721279

@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
16011608
// concept of sign-/zero-extension, so emulate it here by adding an explicit
16021609
// CVT instruction. Ptxas should clean up any redundancies here.
16031610

1604-
EVT OrigType = N->getValueType(0);
16051611
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
16061612

16071613
if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16791685
MVT ScalarVT = SimpleVT.getScalarType();
16801686
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16811687
if (SimpleVT.isVector()) {
1682-
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
1688+
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1689+
"Unexpected vector type");
16831690
// v2x16 and v4i8 are stored using st.b32
16841691
toTypeWidth = 32;
16851692
}

0 commit comments

Comments
 (0)