pyqg-jax: Quasigeostrophic Model in JAX

This is the documentation for the pyqg-jax package, a port of PyQG to JAX.

Porting the model to JAX makes it possible to run it on GPUs and to apply JAX transformations such as jax.jit() and jax.vmap(). It also makes it possible to integrate learned parameterizations into the model, or to train them online through the simulation by taking gradients with jax.grad().
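
To illustrate how these transformations compose on a time-stepped simulation, here is a minimal sketch. It does not use the pyqg-jax API: `step`, `rollout`, and `loss` are hypothetical stand-ins, with a toy linear damping step taking the place of a real model time step and `param` playing the role of a learned parameterization coefficient. The time loop uses jax.lax.scan so that jax.jit(), jax.vmap(), and jax.grad() can each be applied to the whole rollout.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for one model time step (NOT the pyqg-jax API):
# a toy linear damping whose coefficient `param` acts like a learned term.
def step(q, param, dt=0.1):
    return q - dt * param * q


def rollout(q0, param, num_steps=50):
    # Run the toy model forward with lax.scan so JAX traces one loop body
    # instead of unrolling a Python for-loop.
    def body(q, _):
        return step(q, param), None

    q_final, _ = jax.lax.scan(body, q0, None, length=num_steps)
    return q_final


# jax.jit: compile the full rollout; num_steps fixes the loop length, so it
# must be static.
fast_rollout = jax.jit(rollout, static_argnames=("num_steps",))

# jax.vmap: run a batch of initial conditions (axis 0) with a shared param.
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, None)))


# jax.grad: differentiate a loss on the final state through every time step,
# i.e. "online" training of the parameter through the simulation.
def loss(param, q0, target):
    pred = rollout(q0, param)
    return jnp.mean((pred - target) ** 2)


q0 = jnp.linspace(1.0, 2.0, 8)  # toy initial condition
q_final = fast_rollout(q0, 0.5)
batched_final = batched_rollout(jnp.stack([q0, 2.0 * q0, -q0]), 0.5)
target = q_final  # pretend 0.5 is the "true" coefficient
g = jax.grad(loss)(0.6, q0, target)  # gradient w.r.t. param at a wrong guess
```

Because the toy step is linear, the final state equals q0 * (1 - dt * param) ** num_steps, which makes the sketch easy to check. A real model step would replace `step`, while the pattern of scanning the step and transforming the rollout stays the same.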

That said, a note on the state of the port:

Warning

This is a partial, early stage port. There may be bugs and other numerical issues. The API may evolve as work continues.

Even so, we hope that the port will be useful. We have used it successfully in ongoing research projects, and hope that others can do so as well.
