States

Base model state objects.

In pyqg-jax, model states are separated into immutable objects to more closely match the rest of JAX.

Models manipulate instances of PseudoSpectralState, which stores only the q component from which the other model variables are derived.

To access these variables, a state can be expanded into a FullPseudoSpectralState, which provides the remaining attributes.

Note

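When time-stepping in JAX (in particular with jax.lax.scan()), the state objects will have leading time dimensions in addition to their normal spatial dimensions. JAX stores these in “structure-of-arrays” style: a trajectory is a single state object whose arrays each carry the extra leading dimension, rather than a list of separate states.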

To slice into states, consider combining jax.tree.map() with a lambda or a combination of operator.itemgetter() and slice.
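The slicing pattern from the note above can be sketched with a hypothetical `ToyState` NamedTuple standing in for a real state object (NamedTuples are pytrees in JAX, so jax.tree.map() reaches their array leaves the same way). The `toy_step` function is a placeholder, not a model time step; with a real model you would scan over its own stepping function instead.

```python
import operator
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for a state object; NamedTuples are JAX pytrees.
    qh: jax.Array


def toy_step(state, _):
    # Placeholder "time step": halve the field and record the new state.
    new_state = ToyState(qh=state.qh * 0.5)
    return new_state, new_state


init = ToyState(qh=jnp.ones((2, 4, 3), dtype=jnp.complex64))
# scan stacks the recorded states, so each leaf gains a leading time axis.
_, traj = jax.lax.scan(toy_step, init, None, length=5)

# Pick out the final time step with a lambda.
last = jax.tree.map(lambda leaf: leaf[-1], traj)
# Pick out every other time step with itemgetter and slice.
every_other = jax.tree.map(operator.itemgetter(slice(None, None, 2)), traj)

print(traj.qh.shape)         # (5, 2, 4, 3)
print(last.qh.shape)         # (2, 4, 3)
print(every_other.qh.shape)  # (3, 2, 4, 3)
```

Because the whole trajectory is one object, a single jax.tree.map() call slices every array in it consistently.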

class pyqg_jax.state.PseudoSpectralState(*, qh, _q_shape)

Core state evolved by a model instance.

This is the innermost state type evolved by the models. This state can be expanded into a FullPseudoSpectralState by calling methods on the models.

Warning

You should not construct this class yourself. Instead, retrieve instances from a model and, if desired, use update() to replace their attributes.

q

Potential vorticity in real space.

This entry has shape (nz, ny, nx).

Type:

jax.Array

qh

Potential vorticity in spectral space.

This is the spectral form of q.

Type:

jax.Array
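The relationship between q and qh can be illustrated under an assumption not stated on this page: that, as in pyqg, the spectral form is a two-dimensional real FFT over the horizontal (y, x) axes, so the last spectral axis has length nx // 2 + 1. The sketch below calls jax.numpy.fft directly rather than any pyqg-jax function.

```python
import jax
import jax.numpy as jnp

nz, ny, nx = 2, 8, 8
key = jax.random.PRNGKey(0)
# A random real-space field with shape (nz, ny, nx).
q = jax.random.normal(key, (nz, ny, nx), dtype=jnp.float32)

# Assumption: the spectral form is a 2D real FFT over the horizontal axes.
qh = jnp.fft.rfftn(q, axes=(-2, -1))
print(qh.shape, qh.dtype)  # (2, 8, 5) complex64

# Inverting requires the real-space sizes, because rfft keeps only
# nx // 2 + 1 coefficients along the last axis.
q_back = jnp.fft.irfftn(qh, s=(ny, nx), axes=(-2, -1))
print(jnp.allclose(q, q_back, atol=1e-5))  # True
```

Since the inverse transform needs the real-space shape, a spectral-only state has to carry that shape alongside qh; this may be the role of the _q_shape constructor argument, though that is an inference rather than something documented here.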

update(**kwargs)

Replace the value stored in this state.

This function produces a new state object, containing the replacement value.

The keyword arguments may be either q or qh (not both), allowing the replacement value to be provided in spectral form if desired.

The object this method is called on is not modified.

Parameters:

**kwargs – The replacement value, passed by name as either q or qh.

Returns:

A copy of this object with the specified values replaced.

Return type:

PseudoSpectralState

Raises:
  • ValueError – If the shape of the replacement does not match the existing shape or if duplicate updates are specified.

  • TypeError – If the dtype of the replacement does not match the existing type.
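The documented contract of update() can be mirrored by a small toy class. Everything here is hypothetical: `ToyState` is not part of pyqg-jax, its FFT convention is an assumption, and with a real state you would simply call state.update(q=new_q) or state.update(qh=new_qh). What the sketch shows is the behaviour described above: a new object is returned, the original is untouched, and duplicate updates or mismatched shapes and dtypes are rejected.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in that mimics the documented update() contract."""

    qh: jax.Array
    q_shape: tuple  # real-space shape (nz, ny, nx)

    @property
    def q(self):
        # Assumption: real space is recovered with an inverse 2D real FFT.
        return jnp.fft.irfftn(self.qh, s=self.q_shape[-2:], axes=(-2, -1))

    def update(self, **kwargs):
        # Exactly one of q or qh may be given; duplicates are rejected.
        if len(kwargs) != 1 or not set(kwargs) <= {"q", "qh"}:
            raise ValueError("specify exactly one of q or qh")
        name, value = next(iter(kwargs.items()))
        current = getattr(self, name)
        if value.shape != current.shape:
            raise ValueError(f"shape mismatch for {name}")
        if value.dtype != current.dtype:
            raise TypeError(f"dtype mismatch for {name}")
        if name == "q":
            value = jnp.fft.rfftn(value, axes=(-2, -1))
        # Build and return a new object; self is never modified.
        return dataclasses.replace(self, qh=value)


state = ToyState(qh=jnp.zeros((2, 8, 5), dtype=jnp.complex64), q_shape=(2, 8, 8))
new_state = state.update(q=jnp.ones((2, 8, 8), dtype=jnp.float32))
print(bool(jnp.all(state.qh == 0)))                     # True: original unchanged
print(bool(jnp.allclose(new_state.q, 1.0, atol=1e-5)))  # True
```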

class pyqg_jax.state.FullPseudoSpectralState(*, state, ph, u, v, dqhdt)

Full state including calculated values expanded by a model.

This is an expanded form of PseudoSpectralState which includes additional attributes calculated as part of running one of the models.

Warning

You should not construct this class yourself. Instead, retrieve instances from a model and, if desired, use update() to replace their attributes.

state

Inner, partial state providing values for q and qh.

Type:

PseudoSpectralState

q

Potential vorticity in real space.

Pass-through accessor for state.q

Type:

jax.Array

qh

Potential vorticity in spectral space.

Pass-through accessor for state.qh

Type:

jax.Array

p

Streamfunction in real space.

Type:

jax.Array

ph

Streamfunction in spectral space.

Type:

jax.Array

u

Zonal velocity anomaly in real space.

Type:

jax.Array

uh

Zonal velocity anomaly in spectral space.

Type:

jax.Array

v

Meridional velocity anomaly in real space.

Type:

jax.Array

vh

Meridional velocity anomaly in spectral space.

Type:

jax.Array
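This page does not define the velocities in terms of the streamfunction. Assuming the usual quasigeostrophic convention also used by pyqg, u = -dp/dy and v = dp/dx, so in spectral space uh = -i l ph and vh = i k ph. The sketch checks this on a single-mode streamfunction; the wavenumber layout (2π times the FFT frequencies, matching rfft2 output) is likewise an assumption.

```python
import math

import jax.numpy as jnp

ny, nx = 16, 16
L = W = 2 * math.pi
dx, dy = L / nx, W / ny
x = jnp.arange(nx) * dx
y = jnp.arange(ny) * dy
yy, xx = jnp.meshgrid(y, x, indexing="ij")  # shape (ny, nx)

# Single-mode streamfunction p = sin(x) * cos(2y).
p = jnp.sin(xx) * jnp.cos(2 * yy)
ph = jnp.fft.rfft2(p)  # shape (ny, nx // 2 + 1)

# Angular wavenumbers laid out to match the rfft2 output.
k = 2 * math.pi * jnp.fft.rfftfreq(nx, d=dx)  # along x
l = 2 * math.pi * jnp.fft.fftfreq(ny, d=dy)   # along y
ll, kk = jnp.meshgrid(l, k, indexing="ij")    # shape (ny, nx // 2 + 1)

uh = -1j * ll * ph  # u = -dp/dy
vh = 1j * kk * ph   # v =  dp/dx
u = jnp.fft.irfft2(uh, s=(ny, nx))
v = jnp.fft.irfft2(vh, s=(ny, nx))

# Analytic derivatives of the chosen p, for comparison.
u_exact = 2 * jnp.sin(xx) * jnp.sin(2 * yy)
v_exact = jnp.cos(xx) * jnp.cos(2 * yy)
```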

dqhdt

Time derivative of qh, in spectral space.

This value is the tendency used to advance the model state when time stepping.

Type:

jax.Array

dqdt

Real space version of dqhdt.

Type:

jax.Array
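The role of dqhdt can be made concrete with the simplest possible use of a tendency: one forward Euler step, qh + dt * dqhdt, with dqdt recovered by the same inverse transform used for q. This is only an illustration; the steppers in pyqg-jax are not necessarily forward Euler, and the FFT convention, the linear-damping tendency, and the value of dt are all assumptions.

```python
import jax.numpy as jnp

nz, ny, nx = 2, 8, 8
dt = 0.1  # hypothetical time step

qh = jnp.fft.rfftn(jnp.ones((nz, ny, nx), dtype=jnp.float32), axes=(-2, -1))
# Toy tendency: linear damping, dqhdt = -r * qh.
r = 0.5
dqhdt = -r * qh

# Forward Euler update of the spectral state.
qh_next = qh + dt * dqhdt

# Real-space views, assuming the same inverse FFT used to recover q.
dqdt = jnp.fft.irfftn(dqhdt, s=(ny, nx), axes=(-2, -1))
q_next = jnp.fft.irfftn(qh_next, s=(ny, nx), axes=(-2, -1))
```

With q = 1 everywhere, the step gives q_next ≈ 1 + 0.1 * (-0.5) = 0.95 and dqdt ≈ -0.5 at every grid point.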

Notes

Changed in version 0.7.0: Removed attributes uq, vq, uqh, and vqh.

update(**kwargs)

Replace values stored in this state.

This function produces a new state object, with specified attributes replaced.

The keyword arguments may specify any of this class’s attributes except state, but must not apply multiple updates to the same attribute. That is, you may not modify both the spectral and real-space forms of one variable (for example, u and uh) in the same call.

The object this method is called on is not modified.

Parameters:

**kwargs – The replacement values, each passed by attribute name.

Returns:

A copy of this object with the specified values replaced.

Return type:

FullPseudoSpectralState

Raises:
  • ValueError – If the shape of the replacement does not match the existing shape or if duplicate updates are specified.

  • TypeError – If the dtype of the replacement does not match the existing type.
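The rule against updating both forms of one variable amounts to a check over real/spectral name pairs. The sketch is hypothetical: the `PAIRS` table and `check_updates()` function are illustrative and not part of pyqg-jax, although the pairs follow the attributes documented above.

```python
# Hypothetical table pairing each real-space attribute with its spectral form,
# following the attributes documented for FullPseudoSpectralState.
PAIRS = {"q": "qh", "p": "ph", "u": "uh", "v": "vh", "dqdt": "dqhdt"}
ALLOWED = set(PAIRS) | set(PAIRS.values())


def check_updates(**kwargs):
    """Hypothetical validator: raise ValueError if kwargs break the update() rules."""
    unknown = set(kwargs) - ALLOWED
    if unknown:
        # This toy rejects anything else, including "state".
        raise ValueError(f"cannot update: {sorted(unknown)}")
    for real, spectral in PAIRS.items():
        if real in kwargs and spectral in kwargs:
            raise ValueError(f"cannot update both {real} and {spectral}")


check_updates(u=0.0, vh=0.0)  # accepted: different variables
try:
    check_updates(u=0.0, uh=0.0)  # rejected: two forms of the same variable
except ValueError as err:
    print(err)  # cannot update both u and uh
```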

The numerical precision of a model's data types can be selected with a constructor argument; the Precision enumeration lists the available options.

class pyqg_jax.state.Precision(*values)

Enumeration for model precision levels.

When constructing a base model, use values of this enumeration to select the numerical precision which should be used for the states and internal calculations.

Double precision may be significantly slower, for example on GPUs.

Members of this enum have attributes dtype_real and dtype_complex providing the dtype objects used at each precision level.

SINGLE

Single precision.

Models will use jax.numpy.float32 and jax.numpy.complex64.

DOUBLE

Double precision.

Models will use jax.numpy.float64 and jax.numpy.complex128.

Ensure that JAX has 64-bit precision enabled (the jax_enable_x64 configuration option); otherwise JAX falls back to 32-bit types.
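Whether 64-bit types are actually available can be checked at runtime. The sketch relies only on standard JAX behaviour: with 64-bit mode off, requesting jnp.float64 produces a 32-bit array and a warning rather than an error. The helper name `x64_enabled()` is hypothetical.

```python
import warnings

import jax.numpy as jnp


def x64_enabled() -> bool:
    """Hypothetical helper: report whether JAX will honour 64-bit dtypes."""
    with warnings.catch_warnings():
        # Without 64-bit mode, JAX warns and truncates float64 to float32.
        warnings.simplefilter("ignore")
        return bool(jnp.zeros((), dtype=jnp.float64).dtype == jnp.float64)


# 64-bit mode is normally switched on once, at program start-up:
#     import jax
#     jax.config.update("jax_enable_x64", True)
print("64-bit enabled:", x64_enabled())
```

Checking this before building a DOUBLE model avoids silently running the "double precision" configuration with 32-bit arrays.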

Notes

Changed in version 0.9.0: Added dtype_real and dtype_complex attributes.

Models also expose information about the grid on which values are computed.

class pyqg_jax.state.Grid(*, nz, ny, nx, L, W, Hi)

Information on the spatial grid used by a model.

The models in this package use an Arakawa A-grid for real space grids. This class also provides information on the shapes of arrays storing real and spectral values and the distances along each grid edge.

Added in version 0.8.0.

Warning

You should not construct this class yourself. Instead, you should retrieve instances from a model.

real_state_shape

Tuple specifying the shape of arrays for real space variables.

Type:

tuple[int, int, int]

spectral_state_shape

Tuple specifying the shape of arrays for spectral variables.

Type:

tuple[int, int, int]
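The real and spectral shapes differ because the horizontal transform is real-to-complex. Assuming, as in pyqg, a 2D real FFT over the (y, x) axes, the spectral shape is (nz, ny, nx // 2 + 1), which would make nl = ny and nk = nx // 2 + 1. The sketch derives the spectral shape from the transform itself via jax.eval_shape rather than from a formula; the helper name `spectral_shape_for()` is hypothetical.

```python
import jax
import jax.numpy as jnp


def spectral_shape_for(real_shape):
    """Hypothetical helper: spectral shape produced by a 2D rfft of (nz, ny, nx)."""
    # eval_shape computes the output shape without running the FFT.
    out = jax.eval_shape(
        lambda a: jnp.fft.rfftn(a, axes=(-2, -1)),
        jax.ShapeDtypeStruct(real_shape, jnp.float32),
    )
    return out.shape


print(spectral_shape_for((2, 64, 64)))  # (2, 64, 33)
print(spectral_shape_for((1, 32, 63)))  # (1, 32, 32)
```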

nx

Number of grid points in the x direction.

Type:

int

ny

Number of grid points in the y direction.

Type:

int

nz

Number of grid points in the z direction.

Type:

int

L

Domain length in the x direction.

Type:

float

W

Domain length in the y direction.

Type:

float

H

Domain length in the z direction.

In most cases this is actually a JAX float scalar or tracer rather than a Python float.

Type:

float

Hi

The thickness of each layer in the z direction.

This is a vector of length nz whose entries sum to H.

Type:

jax.Array
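The relationship between Hi and H can be sketched directly with jax.numpy; the layer thicknesses below are made up. If H is computed as a reduction over Hi in this way (an assumption, not documented here), that would also explain why it is usually a JAX scalar rather than a Python float.

```python
import jax
import jax.numpy as jnp

# Made-up layer thicknesses for a two-layer configuration (nz = 2).
Hi = jnp.array([500.0, 2000.0], dtype=jnp.float32)
H = jnp.sum(Hi)  # total depth: a 0-d JAX array, not a Python float

print(H.shape, float(H))         # () 2500.0
print(isinstance(H, jax.Array))  # True
```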

nk

Number of spectral grid points in the k direction.

Type:

int

nl

Number of spectral grid points in the l direction.

Type:

int

dx

Space between grid points in the x direction.

Type:

float

dy

Space between grid points in the y direction.

Type:

float

dk

Spectral spacing in the k direction.

Type:

float

dl

Spectral spacing in the l direction.

Type:

float
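These spacings are not given as formulas on this page. Assuming the pyqg conventions on a doubly periodic domain, dx = L / nx, dy = W / ny, dk = 2π / L and dl = 2π / W. The sketch computes them and cross-checks dk and dl against the spacing of NumPy's FFT frequencies; the function name `grid_spacings()` and the domain sizes are hypothetical.

```python
import math

import numpy as np


def grid_spacings(nx, ny, L, W):
    """Hypothetical helper returning (dx, dy, dk, dl) under pyqg-style conventions."""
    dx = L / nx
    dy = W / ny
    dk = 2 * math.pi / L
    dl = 2 * math.pi / W
    return dx, dy, dk, dl


nx, ny, L, W = 64, 32, 1.0e6, 5.0e5
dx, dy, dk, dl = grid_spacings(nx, ny, L, W)

# Angular wavenumbers from NumPy's FFT frequencies share the same spacing.
k = 2 * math.pi * np.fft.rfftfreq(nx, d=dx)
l = 2 * math.pi * np.fft.fftfreq(ny, d=dy)
print(np.isclose(k[1] - k[0], dk), np.isclose(l[1] - l[0], dl))  # True True
```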

get_kappa(dtype=Precision.SINGLE)

Information on the wavenumber at each spectral grid point.

Parameters:

dtype (Precision, optional) – Precision for the wavenumber calculations.

Returns:

A two-dimensional grid of wavenumber values at the specified precision.

These values have the shape of a spectral state (see spectral_state_shape) without the leading nz dimension.

Return type:

jax.Array
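A plausible construction of such a grid is sketched below, assuming (as in pyqg) that kappa is the total horizontal wavenumber sqrt(k² + l²), with k laid out along the rfft axis and l along the full FFT axis. The function `toy_kappa()` is hypothetical; it reproduces the documented output shape (the spectral shape without nz) but is not guaranteed to match pyqg-jax's exact values.

```python
import math

import jax.numpy as jnp


def toy_kappa(nx, ny, L, W, dtype=jnp.float32):
    """Hypothetical: total wavenumber sqrt(k**2 + l**2) on a (ny, nx // 2 + 1) grid."""
    k = 2 * math.pi * jnp.fft.rfftfreq(nx, d=L / nx)  # shape (nx // 2 + 1,)
    l = 2 * math.pi * jnp.fft.fftfreq(ny, d=W / ny)   # shape (ny,)
    ll, kk = jnp.meshgrid(l, k, indexing="ij")
    return jnp.sqrt(kk**2 + ll**2).astype(dtype)


kappa = toy_kappa(nx=64, ny=64, L=2 * math.pi, W=2 * math.pi)
print(kappa.shape, kappa.dtype)  # (64, 33) float32
print(float(kappa[0, 0]))        # 0.0 (the mean mode)
```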