|
14 | 14 | #include "MCTargetDesc/NVPTXBaseInfo.h"
|
15 | 15 | #include "NVPTXUtilities.h"
|
16 | 16 | #include "llvm/Analysis/ValueTracking.h"
|
| 17 | +#include "llvm/CodeGen/ISDOpcodes.h" |
17 | 18 | #include "llvm/IR/GlobalValue.h"
|
18 | 19 | #include "llvm/IR/Instructions.h"
|
19 | 20 | #include "llvm/IR/IntrinsicsNVPTX.h"
|
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
|
829 | 830 | case MVT::v2f16:
|
830 | 831 | case MVT::v2bf16:
|
831 | 832 | case MVT::v2i16:
|
| 833 | + case MVT::v4i8: |
832 | 834 | return Opcode_i32;
|
833 | 835 | case MVT::f32:
|
834 | 836 | return Opcode_f32;
|
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
|
910 | 912 | // Vector Setting
|
911 | 913 | unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
|
912 | 914 | if (SimpleVT.isVector()) {
|
913 |
| - assert(Isv2x16VT(LoadedVT) && "Unexpected vector type"); |
| 915 | + assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) && |
| 916 | + "Unexpected vector type"); |
914 | 917 | // v2f16/v2bf16/v2i16 is loaded using ld.b32
|
915 | 918 | fromTypeWidth = 32;
|
916 | 919 | }
|
@@ -1254,19 +1257,23 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
|
1254 | 1257 | SDLoc DL(N);
|
1255 | 1258 | SDNode *LD;
|
1256 | 1259 | SDValue Base, Offset, Addr;
|
| 1260 | + EVT OrigType = N->getValueType(0); |
1257 | 1261 |
|
1258 | 1262 | EVT EltVT = Mem->getMemoryVT();
|
1259 | 1263 | unsigned NumElts = 1;
|
1260 | 1264 | if (EltVT.isVector()) {
|
1261 | 1265 | NumElts = EltVT.getVectorNumElements();
|
1262 | 1266 | EltVT = EltVT.getVectorElementType();
|
1263 | 1267 | // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
|
1264 |
| - if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) || |
1265 |
| - (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) || |
1266 |
| - (EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) { |
| 1268 | + if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) || |
| 1269 | + (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) || |
| 1270 | + (EltVT == MVT::i16 && OrigType == MVT::v2i16)) { |
1267 | 1271 | assert(NumElts % 2 == 0 && "Vector must have even number of elements");
|
1268 |
| - EltVT = N->getValueType(0); |
| 1272 | + EltVT = OrigType; |
1269 | 1273 | NumElts /= 2;
|
| 1274 | + } else if (OrigType == MVT::v4i8) { |
| 1275 | + EltVT = OrigType; |
| 1276 | + NumElts = 1; |
1270 | 1277 | }
|
1271 | 1278 | }
|
1272 | 1279 |
|
@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
|
1601 | 1608 | // concept of sign-/zero-extension, so emulate it here by adding an explicit
|
1602 | 1609 | // CVT instruction. Ptxas should clean up any redundancies here.
|
1603 | 1610 |
|
1604 |
| - EVT OrigType = N->getValueType(0); |
1605 | 1611 | LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
|
1606 | 1612 |
|
1607 | 1613 | if (OrigType != EltVT &&
|
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
|
1679 | 1685 | MVT ScalarVT = SimpleVT.getScalarType();
|
1680 | 1686 | unsigned toTypeWidth = ScalarVT.getSizeInBits();
|
1681 | 1687 | if (SimpleVT.isVector()) {
|
1682 |
| - assert(Isv2x16VT(StoreVT) && "Unexpected vector type"); |
| 1688 | + assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) && |
| 1689 | + "Unexpected vector type"); |
1683 | 1690 | // v2x16 is stored using st.b32
|
1684 | 1691 | toTypeWidth = 32;
|
1685 | 1692 | }
|
|
0 commit comments