pyqg-jax: Quasigeostrophic Model in JAX
This is the documentation for the pyqg-jax package, a port of PyQG to JAX. Porting the model to JAX makes it possible to run it on GPUs and to apply JAX transformations such as jax.jit() and jax.vmap().
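As a rough illustration of what these transformations do, the sketch below applies jax.jit() and jax.vmap() to a hypothetical toy model. The names toy_step (one exact spectral-diffusion step of a single 2D field on a doubly periodic grid) and run (a simulation loop) are made up for this example and stand in for a model time step and a simulation driver. They are not part of the pyqg-jax API.

```python
import jax
import jax.numpy as jnp


def toy_step(q, dt=0.01, nu=1.0):
    """Hypothetical model step: exact spectral diffusion of a 2D periodic field."""
    n = q.shape[-1]
    k = jnp.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers (domain length 2*pi)
    k2 = k[:, None] ** 2 + k[None, :] ** 2
    qh = jnp.fft.fft2(q) * jnp.exp(-nu * k2 * dt)  # each Fourier mode decays exactly
    return jnp.real(jnp.fft.ifft2(qh))


def run(q0, num_steps=100):
    """Advance the toy model num_steps times inside a single jax.lax.scan."""
    def body(q, _):
        return toy_step(q), None

    q_final, _ = jax.lax.scan(body, q0, None, length=num_steps)
    return q_final


# jit: compile the whole run; num_steps sets the loop length, so it must be static.
run_jit = jax.jit(run, static_argnames="num_steps")
# vmap: map run over a leading ensemble axis, then compile the batched version.
run_batch = jax.jit(jax.vmap(run))

key = jax.random.PRNGKey(0)
q0 = jax.random.normal(key, (32, 32))
q1 = run_jit(q0)  # single simulation

q0s = jax.random.normal(key, (4, 32, 32))
q1s = run_batch(q0s)  # four independent simulations in one call
```

Using jax.lax.scan keeps the whole time loop inside one compiled computation; the trade-off is that the loop length has to be known at compile time, which is why num_steps is marked static.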
Porting to JAX also makes it possible to integrate learned parameterizations into the model, or to train them online through the simulation by using jax.grad() to take gradients.
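Along the same lines, here is a minimal sketch of training online through a simulation: jax.value_and_grad() differentiates a loss on the final state with respect to a parameter, through every time step of the rollout. The "learned parameterization" is reduced here to a single hypothetical diffusion coefficient nu, fitted by gradient descent to a synthetic target run made with nu = 0.5. The toy model and all names (toy_step, rollout, loss) are assumptions for illustration, not pyqg-jax code.

```python
import jax
import jax.numpy as jnp


def toy_step(q, nu, dt=0.01):
    """Hypothetical toy model step: exact spectral diffusion with coefficient nu."""
    n = q.shape[-1]
    k = jnp.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers
    k2 = k[:, None] ** 2 + k[None, :] ** 2
    qh = jnp.fft.fft2(q) * jnp.exp(-nu * k2 * dt)
    return jnp.real(jnp.fft.ifft2(qh))


def rollout(nu, q0, num_steps=20):
    """Run the toy model forward; the whole loop is differentiable."""
    def body(q, _):
        return toy_step(q, nu), None

    q_final, _ = jax.lax.scan(body, q0, None, length=num_steps)
    return q_final


def loss(nu, q0, q_target):
    """Mean squared error between the simulated and target final states."""
    return jnp.mean((rollout(nu, q0) - q_target) ** 2)


# Gradient with respect to the first argument (nu), through all time steps.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

q0 = jax.random.normal(jax.random.PRNGKey(0), (32, 32))
q_target = rollout(0.5, q0)  # synthetic "truth" made with nu = 0.5

nu = 1.0  # initial guess
initial_loss, initial_grad = loss_and_grad(nu, q0, q_target)
for _ in range(50):
    _, g = loss_and_grad(nu, q0, q_target)
    nu = nu - 0.5 * g  # plain gradient descent, learning rate 0.5
final_loss = loss(nu, q0, q_target)
```

A real learned parameterization would replace nu with, for example, the parameters of a neural network, but the gradient would still flow back through jax.lax.scan and the FFTs in the same way.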
That said, a note on the state of the port:
Warning
This is a partial, early-stage port. There may be bugs and other numerical issues, and the API may change as work continues.
Even so, we hope that the port will be useful. We have used it successfully in ongoing research projects and hope that others can do the same.