
Commit 2c8c86b

hongxiayang authored and xuhancn committed
fix for launching kernel invalid config error when calling embedding with large index (pytorch#130994)

Fixes pytorch#130806

When an output size of 2147483648 (= 131072 * 16384) is expected, as in the issue above, the call failed with:

```
RuntimeError: HIP error: invalid configuration argument
```

What happened was that the second parameter passed to hipLaunchKernel was an invalid {2147483648, 1, 1}. Two issues were found in Indexing.cu:

1. `ptrdiff_t` was used, but it is a signed type; `outTotalSize >= 2147483648` can cause an overflow when doing [this](https://github.com/pytorch/pytorch/blame/39493aa93419532957e6e5ee97cae842b53b8b59/aten/src/ATen/native/cuda/Indexing.cu#L1367).
2. On ROCm, `std::min` resolved to `::min` and did not work as expected when `outTotalSize >= 2147483648`.

As a result, 2147483648 was sent to hipLaunchKernel as the number of threads per block, a number the GPU does not support. The original code intended to cap this at 128 threads per block. That cap is debatable, since performance may not be good on the latest powerful GPUs (a possible TODO item for a perf update), but at least it does not cause the `invalid configuration argument` error.

[Test] Run the same code snippet from the [issue](pytorch#130806) and print the output, its dim, and its numel(); it now looks like this:

```
output=tensor([[ 0.4044, -0.0244, -0.6865,  ..., -0.7800,  0.1175,  1.6726],
        [-1.0866, -0.1609,  0.3538,  ...,  1.9105,  0.7882,  1.1583],
        [-2.2079,  0.3736,  0.3610,  ..., -0.2658, -0.0459,  1.3077],
        ...,
        [ 0.8753, -0.7482, -0.1978,  ...,  0.9016,  1.1501, -0.5178],
        [-1.5845, -0.6277,  1.4520,  ...,  0.5733, -2.1198, -0.0915],
        [-0.6310, -1.0239, -0.1910,  ...,  0.4309,  0.1630,  0.3239]],
       device='cuda:0'), dim=2, numel=2147483648
```

A large-tensor unit test was added as well:

```
/pytorch# pytest test/nn/test_embedding.py -k test_large_tensors
=========================== test session starts ===========================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0
rootdir: /dockerx/development/pytorch
configfile: pytest.ini
plugins: flakefinder-1.1.0, rerunfailures-14.0, xdist-3.3.1, xdoctest-1.1.0, cpp-2.3.0, hypothesis-5.35.1
collected 288 items / 287 deselected / 1 selected
Running 1 items in this shard

test/nn/test_embedding.py .                                          [100%]

==================== 1 passed, 287 deselected in 3.16s ====================
```

Pull Request resolved: pytorch#130994
Approved by: https://github.com/jeffdaily, https://github.com/xw285cornell
1 parent 62e05bf commit 2c8c86b
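As a standalone illustration of the fix (not part of the commit; the variable names and the 128-threads-per-block cap mirror Indexing.cu, but this is only a sketch under those assumptions), the corrected launch-config math can be reproduced like this:

```cpp
// Minimal sketch, not from the patch: derive a launch configuration for
// outTotalSize elements under the 128-threads-per-block cap that
// Indexing.cu uses. Unsigned 64-bit math cannot overflow at these sizes,
// and the explicit ternary avoids std::min overload resolution entirely
// (the ROCm pitfall described above). The real code also clamps the grid
// to mpc * 8; that part is omitted here.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t outTotalSize = 131072ULL * 16384ULL;  // 2147483648 > INT32_MAX

    const uint64_t threadsPerBlock = (outTotalSize < 128) ? outTotalSize : 128;
    const uint64_t blocks = (outTotalSize + threadsPerBlock - 1) / threadsPerBlock;

    // Prints: blocks=16777216 threads=128, a valid configuration.
    std::printf("blocks=%llu threads=%llu\n",
                static_cast<unsigned long long>(blocks),
                static_cast<unsigned long long>(threadsPerBlock));
    return 0;
}
```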

File tree: 2 files changed, +44 -30 lines changed


aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 33 additions & 30 deletions
```diff
@@ -688,7 +688,7 @@ REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_
 
 
 // Check tensor dimensions for index operations, and return the slice size.
-static ptrdiff_t getSliceSize(const Tensor & dst,
+static size_t getSliceSize(const Tensor & dst,
                               int dim,
                               const Tensor & index,
                               const Tensor & src)
@@ -698,7 +698,7 @@ static ptrdiff_t getSliceSize(const Tensor & dst,
 
   TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar");
 
-  ptrdiff_t dstSliceSize = 1;
+  size_t dstSliceSize = 1;
   TORCH_CHECK(dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds");
   for (const auto d: c10::irange(dstDims)) {
     if (d != dim) {
@@ -710,7 +710,7 @@ static ptrdiff_t getSliceSize(const Tensor & dst,
   TORCH_CHECK(index.numel() == src.size(dim),
               "length of src.size[dim] is not equal to length of indices");
 
-  ptrdiff_t srcSliceSize = 1;
+  size_t srcSliceSize = 1;
   bool mismatch = false;
 
   if (dstDims != srcDims) mismatch = true;
@@ -900,11 +900,11 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c
   // total size of the tensor ignoring dimension `dim`;
   // -the number of index we are choosing, which is the total size
   // of the tensor `index`.
-  const ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_);
-  const ptrdiff_t sourceTotalSize = source.numel();
-  const int64_t selfAddDimSize = self_.size(dim);
-  const ptrdiff_t numIndex = index.numel();
-  const int64_t selfNumel = self_.numel();
+  const uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
+  const uint64_t sourceTotalSize = source.numel();
+  const uint64_t selfAddDimSize = self_.size(dim);
+  const uint64_t numIndex = index.numel();
+  const uint64_t selfNumel = self_.numel();
 
   if (sliceSize == 0) {
     return;
@@ -933,11 +933,11 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c
         selfAddDimSize, selfNumel, reduce_add, alpha_value);                \
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 
-    const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
-    const dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
+    const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
+    const dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
 
-    const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
-    const dim3 largeIndexBlock(std::min(sourceTotalSize, (ptrdiff_t)128));
+    const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
+    const dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
 
     if (cuda::detail::canUse32BitIndexMath(result) &&
         cuda::detail::canUse32BitIndexMath(source) &&
@@ -1073,11 +1073,11 @@ void index_reduce_func_cuda_impl(
   // total size of the tensor ignoring dimension `dim`;
   // -the number of index we are choosing, which is the total size
   // of the tensor `index`.
-  ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_);
-  ptrdiff_t sourceTotalSize = source.numel();
-  int64_t selfReduceDimSize = self_.size(dim);
-  ptrdiff_t numIndex = index.numel();
-  int64_t selfNumel = self_.numel();
+  uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
+  uint64_t sourceTotalSize = source.numel();
+  uint64_t selfReduceDimSize = self_.size(dim);
+  uint64_t numIndex = index.numel();
+  uint64_t selfNumel = self_.numel();
 
   if (sliceSize == 0) {
     return;
@@ -1106,11 +1106,11 @@ void index_reduce_func_cuda_impl(
         selfReduceDimSize, selfNumel, reduce_func, alpha_value);            \
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 
-    dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
-    dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
+    dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
+    dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
 
-    dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
-    dim3 largeIndexBlock(std::min(sourceTotalSize, (ptrdiff_t)128));
+    dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
+    dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
 
     if (cuda::detail::canUse32BitIndexMath(result) &&
         cuda::detail::canUse32BitIndexMath(source) &&
@@ -1342,8 +1342,8 @@ void index_select_out_cuda_impl(
     const Tensor& self,
     long dim,
     const Tensor& index) {
-  ptrdiff_t numIndices = index.numel();
-  int selfDims = self.dim() == 0 ? 1 : self.dim();
+  uint64_t numIndices = index.numel();
+  uint64_t selfDims = self.dim() == 0 ? 1 : self.dim();
 
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
@@ -1364,7 +1364,7 @@ void index_select_out_cuda_impl(
     at::native::resize_output(out, newSize);
   }
 
-  ptrdiff_t outTotalSize = out.numel();
+  uint64_t outTotalSize = out.numel();
   if (outTotalSize == 0) {
     return;
   }
@@ -1376,8 +1376,8 @@ void index_select_out_cuda_impl(
   // total size of the tensor ignoring dimension `dim`;
   // -the number of indices we are choosing, which is the total size
   // of the tensor `indices`.
-  int64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim);
-  ptrdiff_t sliceSize = outTotalSize / numIndices;
+  uint64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim);
+  uint64_t sliceSize = outTotalSize / numIndices;
 
   int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
 
@@ -1400,11 +1400,14 @@ void index_select_out_cuda_impl(
         selfSelectDimSize);                                                 \
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 
-    dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
-    dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
+    dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t) (mpc * 8)));
+    dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
 
-    dim3 largeIndexGrid(std::min(ceil_div(outTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
-    dim3 largeIndexBlock(std::min(outTotalSize, (ptrdiff_t)128));
+    dim3 largeIndexGrid(std::min(ceil_div(outTotalSize, (uint64_t)128), (uint64_t) (mpc * 8)));
+    // for issue https://github.com/pytorch/pytorch/issues/130806 there are two problems
+    // 1: ptrdiff_t was used but it is signed int, outTotalSize of 2147483648 can cause overflow
+    // 2: On ROCm, std::min -> ::min did not work as expected on when outTotalSize>=2147483648
+    dim3 largeIndexBlock( (outTotalSize < 128) ? outTotalSize : 128 );
     if (cuda::detail::canUse32BitIndexMath(out) &&
         cuda::detail::canUse32BitIndexMath(self) &&
         cuda::detail::canUse32BitIndexMath(index)) {
```
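A note on the `ceil_div` calls above: assuming the helper has the usual `(a + b - 1) / b` shape (an assumption for illustration; the actual helper lives in c10 and may differ in detail), the switch to `uint64_t` matters because the `a + b - 1` intermediate is well defined for unsigned types, while for signed types it can overflow, which is undefined behavior:

```cpp
// Hedged sketch of ceil_div as the diff appears to use it; the real
// c10 helper may differ in detail.
#include <cstdint>

template <typename T>
constexpr T ceil_div_sketch(T a, T b) {
    // For unsigned T, a + b - 1 is well defined (modular arithmetic) and,
    // at the sizes involved here, stays comfortably within 64 bits.
    return (a + b - 1) / b;
}

// 2^31 elements split into 128-thread blocks needs 2^24 blocks.
static_assert(ceil_div_sketch<uint64_t>(2147483648ULL, 128ULL) == 16777216ULL,
              "2^31 / 128 == 2^24");

int main() { return 0; }
```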

test/nn/test_embedding.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -12,12 +12,14 @@
     dtypes,
     dtypesIfCUDA,
     instantiate_device_type_tests,
+    largeTensorTest,
     onlyCUDA,
     onlyNativeDeviceTypes,
     skipCUDAIf,
     skipMeta,
     TEST_WITH_ROCM,
 )
+
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
     _assertGradAndGradgradChecks,
@@ -180,6 +182,15 @@ def test_embedding_functional(self):
 
         self.assertEqual(res_old, res_F)
 
+    # https://github.com/pytorch/pytorch/issues/130806
+    @largeTensorTest("40GB", device="cuda")
+    def test_large_tensors(self):
+        input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
+        w = torch.randn([16032, 16384], device="cuda")
+        out = torch.nn.functional.embedding(input, w)
+        self.assertEqual(out.dim(), 2)
+        self.assertEqual(out.numel(), 2147483648)
+
     def test_embedding_bag_functional(self):
         a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
         embeddings = torch.rand(4, 3, requires_grad=True)
```
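The test sizes are chosen so the output lands exactly one element past the signed 32-bit range: 131072 * 16384 = 2^17 * 2^14 = 2^31 = 2147483648, while a signed 32-bit value tops out at 2^31 - 1, so this is the smallest output of this shape that trips the original overflow.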
