diff --git a/requirements.txt b/requirements.txt index 6c5ede2..aea6e91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy>=1.14.0 scipy>=1.1.0 ase>=3.18.0 -jax>=0.2.3 -jaxlib>=0.1.56 \ No newline at end of file +jax>=0.4.20 +jaxlib>=0.4.20 diff --git a/sella/__init__.py b/sella/__init__.py index 9f73d4f..4e366d3 100644 --- a/sella/__init__.py +++ b/sella/__init__.py @@ -1,7 +1,8 @@ +import jax + from .optimize import IRC, Sella from .internal import Internals, Constraints -from jax.config import config -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) __all__ = ['IRC', 'Sella']