
[Bug Report] torch_to_numpy fails when tensors are not on cpu #1108

Closed
mantasu opened this issue Jul 2, 2024 · 1 comment · Fixed by #1109
Labels
bug Something isn't working

Comments

@mantasu
Contributor

mantasu commented Jul 2, 2024

Describe the bug

torch_to_numpy fails when a torch.Tensor is on a non-CPU device (e.g. CUDA) or has requires_grad set to True.

Note

Calling torch_to_jax does not cause these problems when JAX is installed with GPU support. Otherwise it also raises an error, but the error message tells the user to install JAX with GPU support.

Code example

import torch
from gymnasium.wrappers.numpy_to_torch import torch_to_numpy

t = torch.tensor(0.0, device="cuda")  # requires a CUDA-capable GPU
torch_to_numpy(t)  # raises TypeError: can't convert cuda:0 device type tensor to numpy
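The snippet above needs a CUDA device. The requires_grad case from the description can be reproduced on a CPU-only machine. The traceback shows that torch_to_numpy ends in np.array(value), so this minimal sketch calls np.array directly on a grad-tracking tensor, assuming only that torch and NumPy are installed (gymnasium is not used):

```python
import numpy as np
import torch

# A CPU tensor that tracks gradients. np.array() goes through
# Tensor.__array__ -> Tensor.numpy() (see the traceback below), and
# Tensor.numpy() refuses tensors that require grad.
t = torch.tensor(0.0, requires_grad=True)

try:
    np.array(t)
    failed = False
except RuntimeError as err:
    failed = True
    print(f"np.array(t) failed: {err}")

# Detaching first yields a tensor without grad tracking, so conversion works.
arr = np.array(t.detach())
print(repr(arr))
```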

System info

python=3.11, gymnasium=1.0.0a2

Additional context

Error traceback from the example script:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "anaconda/envs/sddl/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "anaconda/envs/sddl/lib/python3.11/site-packages/gymnasium/wrappers/numpy_to_torch.py", line 51, in _number_torch_to_numpy
    return np.array(value)
           ^^^^^^^^^^^^^^^
  File "anaconda/envs/sddl/lib/python3.11/site-packages/torch/_tensor.py", line 1087, in __array__
    return self.numpy()
           ^^^^^^^^^^^^
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
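The error message already names the remedy: copy the tensor to host memory with Tensor.cpu() before converting. Below is a minimal sketch of a conversion that handles both failure modes from the description. It is not necessarily what #1109 merged: the name safe_torch_to_numpy is hypothetical, and the code only mirrors the functools.singledispatch dispatch visible in the traceback. It assumes torch and NumPy are installed and uses CUDA only when torch.cuda.is_available() reports a GPU.

```python
import functools
import numbers
from collections.abc import Mapping
from typing import Any

import numpy as np
import torch


@functools.singledispatch
def safe_torch_to_numpy(value: Any) -> Any:
    """Hypothetical helper (not gymnasium's API): torch data -> NumPy data."""
    raise TypeError(f"No conversion registered for {type(value).__name__}")


@safe_torch_to_numpy.register(numbers.Number)
def _(value: numbers.Number) -> np.ndarray:
    # Plain Python numbers have no device or autograd state.
    return np.array(value)


@safe_torch_to_numpy.register(torch.Tensor)
def _(value: torch.Tensor) -> np.ndarray:
    # detach() drops autograd tracking (avoids the requires_grad RuntimeError);
    # cpu() copies GPU tensors to host memory (avoids the TypeError above)
    # and returns the tensor unchanged if it is already on the CPU.
    return value.detach().cpu().numpy()


@safe_torch_to_numpy.register(Mapping)
def _(value: Mapping) -> dict:
    # Recurse into dictionaries, e.g. dict observations.
    return {key: safe_torch_to_numpy(item) for key, item in value.items()}


@safe_torch_to_numpy.register(list)
@safe_torch_to_numpy.register(tuple)
def _(value):
    # Convert each element of a plain list or tuple, keeping the container type.
    return type(value)(safe_torch_to_numpy(item) for item in value)


# Usage: mix a grad-tracking tensor, a tensor on the best available device,
# a Python number, and a nested tuple.
device = "cuda" if torch.cuda.is_available() else "cpu"
obs = {
    "pos": torch.tensor([1.0, 2.0], requires_grad=True),
    "vel": torch.zeros(2, device=device),
    "step": 3,
    "pair": (torch.tensor(1), 2.5),
}
out = safe_torch_to_numpy(obs)
print({key: type(val).__name__ for key, val in out.items()})
```

Calling detach() before cpu() keeps the device copy out of the autograd graph, and for tensors already on the CPU, cpu() returns the tensor itself, so the extra call does not copy data.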

Checklist

  • I have checked that there is no similar issue in the repo
@pseudo-rnd-thoughts
Member

Thanks for the PR and issue again, looks all good
