PyMC should not need to do anything specific to this backend. Is anything special needed on the JAX side for the GPU to be used? If not, it should work out of the box.
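To illustrate the "out of the box" point, here is a minimal sketch, assuming only that JAX is installed. The function `logp` is a hypothetical stand-in for a model's compiled log-density, not PyMC's actual generated code. Nothing in it names a device: `jax.jit` compiles it for whatever backend JAX selects by default, which is the same mechanism a JAX-based sampler would rely on.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model's log-density: a standard normal in n dimensions.
# No device is specified anywhere; JAX places the computation on its default backend.
@jax.jit
def logp(x):
    return -0.5 * jnp.sum(x ** 2) - 0.5 * x.size * jnp.log(2 * jnp.pi)

x = jnp.zeros(3)
value = logp(x)
print(float(value))
```

If JAX itself can use the GPU in a given environment, code like this would run there without changes.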
Thanks, Ricardo. Whenever I use JAX directly, it does detect the M1 GPU. However, the GPU is not detected when I use `pm.sampling_jax`. Is there a configuration I need to set for that? Or is it perhaps something on the Aesara side?
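One way to narrow this down is to check which devices JAX reports from the same Python environment (ideally the same process, just before sampling) in which `pm.sampling_jax` is used. This is a minimal diagnostic sketch using only `jax.devices()` and `jax.default_backend()`. The exact platform name reported for an Apple GPU depends on the JAX plugin installed, so treat any specific name as an assumption.

```python
import jax

# Every device visible to JAX in this environment, and the backend it uses by default.
devices = jax.devices()
backend = jax.default_backend()

print("default backend:", backend)
for d in devices:
    # d.platform is the backend name the device belongs to (e.g. "cpu").
    print(d.platform, d)
```

If this shows only CPU devices when run alongside PyMC but not in a standalone script, the two are likely picking up different JAX installations or environments.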
I was wondering whether there is a chance to use the Apple M1 GPU in pymc4, since pymc4 can use JAX (through `sampling_jax`) and JAX now recognizes the Apple M1 GPU. PyTorch now also recognizes the Apple M1 GPU; is the same true for NumPyro? And if so, can this carry over to pymc4? Thanks.