diff --git a/autoray/autoray.py b/autoray/autoray.py index dbb5617..cec7751 100644 --- a/autoray/autoray.py +++ b/autoray/autoray.py @@ -1495,6 +1495,7 @@ def jax_to_numpy(x): _BACKEND_ALIASES["jaxlib"] = "jax" +_MODULE_ALIASES["jax.scipy"] = "jax.scipy" _MODULE_ALIASES["jax"] = "jax.numpy" _SUBMODULE_ALIASES["jax", "complex"] = "jax.lax" _SUBMODULE_ALIASES["jax", "linalg.expm"] = "jax.scipy.linalg"