This is a partial port of PyQG to JAX, which enables GPU acceleration, batching, automatic differentiation, and more.
- Documentation: https://pyqg-jax.readthedocs.io/en/latest/
- Source Code: https://github.com/karlotness/pyqg-jax
- Bug Reports: https://github.com/karlotness/pyqg-jax/issues
Install from PyPI using pip:
$ python -m pip install pyqg-jax
or from conda-forge:
$ conda install -c conda-forge pyqg-jax
This should install required dependencies, but JAX itself may require special attention, particularly for GPU support. Follow the JAX installation instructions.
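After installing, you can quickly check which backend JAX picked up by listing its devices. This is a minimal check using standard JAX calls. The device names it prints depend on your installation, and GPU devices appear only if the GPU-enabled JAX packages were installed correctly.

```python
import jax

# Devices JAX can run on; a working GPU install should list GPU devices here
devices = jax.devices()
print(devices)

# Name of the platform JAX uses by default (for example "cpu" or "gpu")
backend = jax.default_backend()
print(backend)
```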
Documentation is a work in progress. The parameters of the QGModel implemented here are the same as those of the model in the original PyQG, so consult the pyqg documentation for details.
However, there are a few overarching changes used to make the models JAX-compatible:
- The model state is now a separate, immutable object rather than being attributes of the QGModel class.
- Time-stepping is now separated from the models. Use steppers.AB3Stepper for the same time stepping as in the original QGModel.
- Random initialization requires an explicit key variable, as with all JAX random number generation.
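The pure-JAX sketch below illustrates the general pattern behind the first and third points: an immutable state object and initialization driven only by an explicit key. `ToyState` and `create_toy_state` are hypothetical stand-ins written for illustration, not part of pyqg-jax; the real model states carry more fields.

```python
import typing

import jax
import jax.numpy as jnp


class ToyState(typing.NamedTuple):
    """Hypothetical immutable state; NamedTuples are treated as JAX pytrees."""

    q: jax.Array
    t: jax.Array


def create_toy_state(key, shape=(2, 8, 8)):
    # All randomness comes from the explicit key argument
    q = 1e-7 * jax.random.normal(key, shape)
    return ToyState(q=q, t=jnp.zeros(()))


key = jax.random.key(0)
state_a = create_toy_state(key)
state_b = create_toy_state(key)  # same key -> identical state

# "Updating" builds a new object; state_a itself is left unchanged
advanced = state_a._replace(t=state_a.t + 1.0)

# Batching: split the key and vmap to build several independent states at once
keys = jax.random.split(key, 4)
batch = jax.vmap(create_toy_state)(keys)
```

Reusing a key reproduces the same state, so independent runs should each get their own key, for example via jax.random.split as above.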
The QGModel uses double-precision (float64) values for part of its computation regardless of the precision setting, so make sure JAX is configured to enable 64-bit values. See the JAX documentation for details. One option is to set the following environment variable:
$ export JAX_ENABLE_X64=True
or to use the %env magic in a Jupyter notebook.
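Another option is to enable 64-bit mode from Python through JAX's configuration API. This should run at program startup, before any JAX arrays are created, since arrays that already exist keep their original dtype.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit values; do this before creating any arrays
jax.config.update("jax_enable_x64", True)

# Default floating-point arrays are now float64
x = jnp.zeros(3)
print(x.dtype)
```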
A short example initializing a QGModel, adding a parameterization, and taking a single step (for more, see the examples in the documentation):
>>> import pyqg_jax
>>> import jax
>>> # Construct model, parameterization, and time-stepper
>>> stepped_model = pyqg_jax.steppers.SteppedModel(
... model=pyqg_jax.parameterizations.smagorinsky.apply_parameterization(
... pyqg_jax.qg_model.QGModel(),
... constant=0.08,
... ),
... stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
... )
>>> # Initialize the model state (wrapped in stepper and parameterization state)
>>> stepper_state = stepped_model.create_initial_state(
... jax.random.key(0)
... )
>>> # Compute next state
>>> next_stepper_state = stepped_model.step_model(stepper_state)
>>> # Unwrap the result from the stepper and parameterization
>>> next_param_state = next_stepper_state.state
>>> next_model_state = next_param_state.model_state
>>> final_q = next_model_state.q
For repeated time-stepping, combine step_model with jax.lax.scan.
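A minimal sketch of that pattern is below. To keep it self-contained, `toy_step_model` is a hypothetical stand-in for `stepped_model.step_model` (a forward-Euler step of dq/dt = -q), and `roll_out` is a helper written for illustration. With pyqg-jax you would pass `stepped_model.step_model` and a stepper state instead, assuming, as the scan-based approach implies, that stepper states are JAX pytrees with a fixed structure.

```python
import jax
import jax.numpy as jnp


def toy_step_model(state):
    # Hypothetical stand-in for stepped_model.step_model:
    # one forward-Euler step of dq/dt = -q with dt = 0.1
    return state - 0.1 * state


def roll_out(step_fn, init_state, num_steps):
    """Apply step_fn num_steps times; return the final state and every state."""

    def body(carry, _):
        next_state = step_fn(carry)
        # Carry the new state forward and also record it in the trajectory
        return next_state, next_state

    return jax.lax.scan(body, init_state, None, length=num_steps)


# step_fn and num_steps are Python values, so mark them static for jit
fast_roll_out = jax.jit(roll_out, static_argnums=(0, 2))
final_state, trajectory = fast_roll_out(toy_step_model, jnp.ones(3), 5)
```

Because scan stacks every per-step output, the trajectory grows with the number of steps. For long runs where only the final state matters, `body` could return `None` as its second value instead.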
This software is distributed under the MIT license. See LICENSE.txt for the license text.