
[rocm] F.embedding reports invalid configuration argument #130806


Closed
xw285cornell opened this issue Jul 16, 2024 · 4 comments
Assignees
Labels
module: rocm AMD GPU support for Pytorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone

Comments

@xw285cornell
Contributor

xw285cornell commented Jul 16, 2024

🐛 Describe the bug

F.embedding crashes with a relatively large tensor input on AMD GPUs:

import torch

input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
w = torch.randn([16032, 16384], device="cuda")
torch.nn.functional.embedding(input, w)

RuntimeError: HIP error: invalid configuration argument
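
For context: the requested output has 131072 * 16384 = 2147483648 elements, exactly 2^31 and one past the largest signed 32-bit value, which hints that a 32-bit overflow is involved. A minimal arithmetic check:

```
num_indices = 131072      # size of the index tensor in the repro
embedding_dim = 16384     # embedding width in the repro
out_total_size = num_indices * embedding_dim

print(out_total_size)               # 2147483648
print(out_total_size == 2**31)      # True
print(out_total_size > 2**31 - 1)   # True: exceeds the max signed 32-bit value
```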

Versions

top of tree

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang

@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Jul 16, 2024
@xw285cornell
Contributor Author

cc. @jeffdaily

@malfet malfet added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 16, 2024
@hongxiayang
Collaborator

@xw285cornell Reproduced. Thanks for reporting.

@jeffdaily : Here is the problem: the number 2147483648 passed to hipLaunchKernel as the block dimension is clearly invalid.

:3:hip_module.cpp           :669 : 12476033797857 us: [pid:79    tid:0x7f8211ea2180] hipLaunchKernel ( 0x7f81f8206740, {832,1,1}, {2147483648,1,1}, 0x7ffd10647410, 0, stream:<null> )
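
In that call, the first brace group {832,1,1} is the grid and the second, {2147483648,1,1}, is the block; HIP devices, like CUDA devices, typically cap threads per block at 1024, so the launch is rejected. A rough sanity-check sketch, assuming a 1024-thread limit (the exact limit is device-dependent):

```
# Hedged sketch: decode the hipLaunchKernel arguments from the log above.
# The 1024 threads-per-block limit is an assumption (typical for HIP/CUDA
# hardware); query the device for the real value.
grid = (832, 1, 1)            # gridDim from the log
block = (2147483648, 1, 1)    # blockDim from the log

MAX_THREADS_PER_BLOCK = 1024  # assumed typical device limit

threads_per_block = block[0] * block[1] * block[2]
if threads_per_block > MAX_THREADS_PER_BLOCK:
    print(f"invalid configuration: {threads_per_block} threads per block "
          f"exceeds the assumed limit of {MAX_THREADS_PER_BLOCK}")
```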

@hongxiayang hongxiayang moved this to Todo in PyTorch on ROCm Jul 16, 2024
@hongxiayang
Collaborator

will put out a PR to fix this soon.

@hongxiayang hongxiayang moved this from Todo to In Progress in PyTorch on ROCm Jul 19, 2024
@hongxiayang hongxiayang self-assigned this Jul 19, 2024
@github-project-automation github-project-automation bot moved this from In Progress to Done in PyTorch on ROCm Jul 20, 2024
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this issue Jul 22, 2024
fix for launching kernel invalid config error when calling embedding with large index (pytorch#130994)

Fixes pytorch#130806
When an output size of 2147483648 elements (= 131072 * 16384) is expected, as in the issue above, it threw the following error:
RuntimeError: HIP error: invalid configuration argument

What happened was that the block-dimension parameter passed to hipLaunchKernel was {2147483648,1,1}.
Two issues were found in Indexing.cu:

1: ptrdiff_t was used, but it is a signed type, so outTotalSize >= 2147483648 can cause overflow when doing [this](https://github.com/pytorch/pytorch/blame/39493aa93419532957e6e5ee97cae842b53b8b59/aten/src/ATen/native/cuda/Indexing.cu#L1367).
2: On ROCm, std::min -> ::min did not work as expected when outTotalSize >= 2147483648.

As a result, 2147483648 was passed to hipLaunchKernel as the number of threads per block, which the GPU does not support. The original code intended to launch 128 threads per block; whether that is optimal for the latest GPUs is debatable (a possible TODO for performance), but at least it does not trigger the `invalid configuration argument` error.
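
To make the intended behavior concrete, here is a hedged illustration in plain Python (not the actual Indexing.cu code) of the clamp the kernel launch meant to apply versus the value that was actually launched:

```
# Hedged illustration, not the actual Indexing.cu code.
out_total_size = 131072 * 16384      # 2147483648 output elements

# Intended: never launch more than 128 threads per block.
intended_block_dim = min(out_total_size, 128)
print(intended_block_dim)            # 128

# Observed (per the log above): the clamp was effectively bypassed by the
# overflow / min issues described here, so the full element count ended up
# as blockDim.x, which no GPU accepts.
observed_block_dim = out_total_size
print(observed_block_dim)            # 2147483648
```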

[Test]
Ran the same code snippet from the [issue](pytorch#130806) and printed the output, its dim, and numel(); they now look like this:
```
output=tensor([[ 0.4044, -0.0244, -0.6865,  ..., -0.7800,  0.1175,  1.6726],
        [-1.0866, -0.1609,  0.3538,  ...,  1.9105,  0.7882,  1.1583],
        [-2.2079,  0.3736,  0.3610,  ..., -0.2658, -0.0459,  1.3077],
        ...,
        [ 0.8753, -0.7482, -0.1978,  ...,  0.9016,  1.1501, -0.5178],
        [-1.5845, -0.6277,  1.4520,  ...,  0.5733, -2.1198, -0.0915],
        [-0.6310, -1.0239, -0.1910,  ...,  0.4309,  0.1630,  0.3239]],
       device='cuda:0'), dim=2, numel=2147483648
```

Added a large tensor unit test too.
```
/pytorch# pytest test/nn/test_embedding.py -k test_large_tensors
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0
rootdir: /dockerx/development/pytorch
configfile: pytest.ini
plugins: flakefinder-1.1.0, rerunfailures-14.0, xdist-3.3.1, xdoctest-1.1.0, cpp-2.3.0, hypothesis-5.35.1
collected 288 items / 287 deselected / 1 selected
Running 1 items in this shard

test/nn/test_embedding.py .                                                                                                                                                        [100%]

=========================================================================== 1 passed, 287 deselected in 3.16s ============================================================================
```
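
For a sense of what such a test might look like, here is a hedged sketch; the actual test added in pytorch#130994 may differ in name, structure, and assertions, and running it needs a GPU with enough free memory (the output alone has 2^31 elements, about 8 GiB in float32):

```
# Hedged sketch of a large-output embedding test; the real test in
# pytorch#130994 may differ. Requires a CUDA/ROCm device with plenty of
# free memory.
import torch

def test_embedding_large_output():
    num_embeddings, embedding_dim, num_indices = 16032, 16384, 131072
    idx = torch.randint(0, num_embeddings, (num_indices,), device="cuda")
    weight = torch.randn(num_embeddings, embedding_dim, device="cuda")
    out = torch.nn.functional.embedding(idx, weight)
    assert out.shape == (num_indices, embedding_dim)
    assert out.numel() == num_indices * embedding_dim  # 2147483648
```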
Pull Request resolved: pytorch#130994
Approved by: https://github.com/jeffdaily, https://github.com/xw285cornell
xuhancn pushed a commit to xuhancn/pytorch that referenced this issue Jul 25, 2024
fix for launching kernel invalid config error when calling embedding with large index (pytorch#130994)

(Same commit message as above.)
pytorchbot pushed a commit that referenced this issue Aug 13, 2024
fix for launching kernel invalid config error when calling embedding with large index (#130994)

(Same commit message as above.)

(cherry picked from commit 637ab85)
atalman pushed a commit that referenced this issue Aug 14, 2024
fix for launching kernel invalid config error when calling embedding with large index (#133346)

Backport of #130994; same commit message as above.

(cherry picked from commit 637ab85)

Co-authored-by: hongxyan <hongxyan@amd.com>
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue Aug 15, 2024
fix for launching kernel invalid config error when calling embedding with large index (pytorch#133346)

Backport of pytorch#130994; same commit message as above.

(cherry picked from commit 637ab85)

Co-authored-by: hongxyan <hongxyan@amd.com>
@atalman atalman added this to the 2.4.1 milestone Aug 29, 2024
@jithunnair-amd
Collaborator

jithunnair-amd commented Sep 3, 2024

Verified with PyTorch 2.4.1 final RC wheels on MI100/MI210:
pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/test/rocm6.1

>>> import torch
>>> input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
>>> w = torch.randn([16032, 16384], device="cuda")
>>> torch.nn.functional.embedding(input, w)
tensor([[-1.4506,  0.4196, -0.8153,  ..., -0.3261,  0.1768, -1.3584],
        [ 0.7569, -0.1229, -1.6053,  ..., -0.0411,  0.4474, -1.1425],
        [-0.1418, -1.1572,  1.2050,  ...,  0.1415, -0.2598,  0.5112],
        ...,
        [-0.2178, -0.3899, -0.5906,  ...,  0.3875,  0.5836,  0.0545],
        [-0.4631,  0.0565,  1.9452,  ..., -1.0769, -0.0736, -0.7225],
        [ 0.4030,  0.6888,  0.2773,  ...,  0.1789,  2.0047,  0.7733]],
       device='cuda:0')

Labels
module: rocm AMD GPU support for Pytorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
Status: Done

5 participants