States
Base model state objects.
In pyqg-jax, model states are represented as immutable objects to more closely match the rest of JAX.
Models manipulate instances of PseudoSpectralState, which
stores only the q component (in spectral form, as qh), from which the other model variables are
derived.
To access these variables, the states can be expanded into a
FullPseudoSpectralState with the remaining attributes.
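As a rough illustration of this design, the sketch below defines a hypothetical immutable pytree, `SketchState` (not part of pyqg-jax), that stores only the spectral field `qh` and derives the real-space `q` on demand. "Updating" it returns a new object. It assumes that spectral fields are real FFTs over the last two axes; pyqg-jax's actual classes may be implemented differently.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class SketchState:
    """Hypothetical stand-in for an immutable state that stores only qh."""

    def __init__(self, qh, q_shape):
        self._qh = qh
        self._q_shape = tuple(q_shape)  # static metadata, not an array leaf

    @property
    def qh(self):
        return self._qh

    @property
    def q(self):
        # Real-space q is derived on demand from the stored spectral field
        return jnp.fft.irfftn(self._qh, s=self._q_shape[-2:], axes=(-2, -1))

    def update(self, qh):
        # Return a new object instead of mutating this one
        return SketchState(qh, self._q_shape)

    def tree_flatten(self):
        # qh is the only array leaf; the real-space shape travels as aux data
        return (self._qh,), self._q_shape

    @classmethod
    def tree_unflatten(cls, q_shape, children):
        return cls(children[0], q_shape)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8))
state = SketchState(jnp.fft.rfftn(q, axes=(-2, -1)), q.shape)
doubled = state.update(2 * state.qh)
print(state.qh.shape)  # (2, 8, 5)
```

Keeping the real-space shape as static auxiliary data means only `qh` is traced by transformations such as `jax.jit` or `jax.lax.scan`.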
Note
When time-stepping in JAX (in particular with
jax.lax.scan()), the state objects will have leading time
dimensions in addition to their normal spatial dimensions. JAX
will store these in “structure-of-arrays” style.
To slice into states, consider combining
jax.tree.map() with a lambda or a combination
of operator.itemgetter() and slice.
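The sketch below shows this structure-of-arrays layout using a hypothetical `ToyState` named tuple as a stand-in for a model state (a `NamedTuple` is automatically a JAX pytree). `jax.lax.scan()` stacks each per-step output along a new leading axis of every leaf, and either slicing approach recovers individual steps. It uses `jax.tree_util.tree_map`; `jax.tree.map` is the newer alias for the same function.

```python
import operator
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a model state (one field per array leaf)."""
    qh: jax.Array
    t: jax.Array


def step(state, _):
    # Toy update: advance time and damp qh; emit the new state as scan output
    new_state = ToyState(qh=0.5 * state.qh, t=state.t + 1.0)
    return new_state, new_state


init = ToyState(qh=jnp.ones((2, 4, 3), dtype=jnp.complex64), t=jnp.float32(0.0))
_, traj = jax.lax.scan(step, init, None, length=5)

# Every leaf gains a leading time axis of length 5
print(traj.qh.shape)  # (5, 2, 4, 3)
print(traj.t.shape)   # (5,)

# Select one time step with a lambda
third = jax.tree_util.tree_map(lambda leaf: leaf[2], traj)

# Select a range of time steps with operator.itemgetter and slice
last_two = jax.tree_util.tree_map(operator.itemgetter(slice(-2, None)), traj)
```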
- class pyqg_jax.state.PseudoSpectralState(*, qh, _q_shape)
Core state evolved by a model instance.
This is the innermost state type evolved by the models. This state can be expanded into a
FullPseudoSpectralState by calling methods on the models.
Warning
You should not construct this class yourself. Instead, you should retrieve instances from a model and, if desired,
update() their attributes.
- update(**kwargs)
Replace the value stored in this state.
This function produces a new state object, containing the replacement value.
The keyword arguments may be either q or qh (not both), allowing the replacement value to be provided in spectral form if desired.
The object this method is called on is not modified.
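- Parameters:
**kwargs – The replacement value, given as either q (real space) or qh (spectral form).
- Returns:
A copy of this object with the specified values replaced.
- Return type:
PseudoSpectralState
- Raises: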
ValueError – If the shape of the replacement does not match the existing shape or if duplicate updates are specified.
TypeError – If the dtype of the replacement does not match the existing type.
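The rules above (at most one of `q` or `qh`, a new value returned, shape and dtype checked) can be sketched with plain JAX arrays. The function `sketch_update` is hypothetical and only mirrors the documented behavior. Converting `q` to spectral form with `jnp.fft.rfftn` over the last two axes, checking the dtype after that conversion, and returning the input unchanged when nothing is passed are all assumptions, not pyqg-jax's code.

```python
import jax.numpy as jnp


def sketch_update(current_qh, *, q=None, qh=None):
    """Hypothetical mirror of the update() rules; returns a new qh array."""
    if q is not None and qh is not None:
        raise ValueError("duplicate updates: pass q or qh, not both")
    if q is None and qh is None:
        return current_qh  # nothing to replace (behavior assumed here)
    # A real-space replacement is transformed to spectral form (assumed rfftn)
    new_qh = qh if qh is not None else jnp.fft.rfftn(q, axes=(-2, -1))
    if new_qh.shape != current_qh.shape:
        raise ValueError(f"shape mismatch: {new_qh.shape} != {current_qh.shape}")
    if new_qh.dtype != current_qh.dtype:
        raise TypeError(f"dtype mismatch: {new_qh.dtype} != {current_qh.dtype}")
    return new_qh  # current_qh itself is never modified


qh0 = jnp.zeros((2, 8, 5), dtype=jnp.complex64)
qh1 = sketch_update(qh0, q=jnp.ones((2, 8, 8), dtype=jnp.float32))
print(qh1.shape, qh1.dtype)  # (2, 8, 5) complex64
```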
- class pyqg_jax.state.FullPseudoSpectralState(*, state, ph, u, v, dqhdt)
Full state including calculated values expanded by a model.
This is an expanded form of
PseudoSpectralState that includes additional attributes calculated as part of running one of the models.
Warning
You should not construct this class yourself. Instead, you should retrieve instances from a model and, if desired,
update() their attributes.
- dqhdt
Time derivative of the spectral field
qh. This value is the update applied to the model when time stepping.
- Type:
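To make the role of `dqhdt` concrete, here is a forward-Euler sketch. This is a simplification chosen for illustration: the hypothetical `toy_dqhdt` is a linear damping of every spectral mode, and the actual models and steppers in pyqg-jax may compute the tendency differently and use a higher-order time-stepping scheme.

```python
import jax.numpy as jnp


def toy_dqhdt(qh, damping=0.1):
    # Hypothetical tendency: linear damping of every spectral mode
    return -damping * qh


def euler_step(qh, dt):
    # The tendency dqhdt is the rate of change applied to qh over one step
    return qh + dt * toy_dqhdt(qh)


qh = jnp.ones((1, 4, 3), dtype=jnp.complex64)
for _ in range(10):
    qh = euler_step(qh, dt=0.5)
# Each step multiplies qh by (1 - 0.5 * 0.1) = 0.95
```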
Notes
Changed in version 0.7.0: Removed attributes uq, vq, uqh, and vqh.
- update(**kwargs)
Replace values stored in this state.
This function produces a new state object, with the specified attributes replaced.
The keyword arguments may specify any of this class’s attributes except
state, but must not apply multiple updates to the same attribute. That is, modifying both the spectral and real space values of the same quantity at the same time is not allowed.
The object this method is called on is not modified.
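- Parameters:
**kwargs – Replacement values for any attributes of this state other than state.
- Returns:
A copy of this object with the specified values replaced.
- Return type:
FullPseudoSpectralState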
- Raises:
ValueError – If the shape of the replacement does not match the existing shape or if duplicate updates are specified.
TypeError – If the dtype of the replacement does not match the existing type.
Model instances can have their numerical precision selected as a
constructor argument. The enumeration Precision lists the
available options.
- class pyqg_jax.state.Precision(*values)
Enumeration for model precision levels.
When constructing a base model, use values of this enumeration to select the numerical precision which should be used for the states and internal calculations.
Double precision may be significantly slower, for example on GPUs.
Members of this enum have attributes
dtype_real and dtype_complex providing the dtype objects used at each precision level.
- SINGLE
Single precision.
Models will use
jax.numpy.float32 and jax.numpy.complex64.
- DOUBLE
Double precision.
Models will use
jax.numpy.float64 and jax.numpy.complex128. Ensure that JAX has 64-bit precision enabled.
Notes
Changed in version 0.9.0: Added
dtype_real and dtype_complex attributes.
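A minimal sketch of the 64-bit requirement, using only JAX itself: `jax.config.update("jax_enable_x64", True)` is JAX's switch for 64-bit types and is best set at program start. `SketchPrecision` is a hypothetical stand-in that mirrors the documented `dtype_real` and `dtype_complex` attributes; it is not the pyqg-jax class.

```python
import enum

import jax
import jax.numpy as jnp

# Enable 64-bit types before creating any arrays (ideally at program start)
jax.config.update("jax_enable_x64", True)


class SketchPrecision(enum.Enum):
    """Hypothetical mirror of Precision with per-level dtype attributes."""
    SINGLE = (jnp.float32, jnp.complex64)
    DOUBLE = (jnp.float64, jnp.complex128)

    @property
    def dtype_real(self):
        return jnp.dtype(self.value[0])

    @property
    def dtype_complex(self):
        return jnp.dtype(self.value[1])


prec = SketchPrecision.DOUBLE
x = jnp.zeros((4, 4), dtype=prec.dtype_real)
xh = jnp.fft.rfftn(x)
print(x.dtype, xh.dtype)  # float64 complex128 when x64 is enabled
```

Without the configuration change, JAX silently falls back to 32-bit types when 64-bit dtypes are requested, which is why the DOUBLE level needs it.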
Models also expose information about the grid on which values are computed.
- class pyqg_jax.state.Grid(*, nz, ny, nx, L, W, Hi)
Information on the spatial grid used by a model.
The models in this package use an Arakawa A-grid for real space grids. This class also provides information on the shapes of arrays storing real and spectral values and the distances along each grid edge.
Added in version 0.8.0.
Warning
You should not construct this class yourself. Instead, you should retrieve instances from a model.
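As a sketch of the relationships this class describes, the hypothetical helper `grid_summary` (not part of pyqg-jax) computes real and spectral array shapes and grid spacings from `nz`, `ny`, `nx`, `L`, and `W`. It assumes that spectral arrays come from a real FFT over the last two axes (giving `nx // 2 + 1` columns) and that the spacings are `L / nx` and `W / ny`.

```python
import jax.numpy as jnp


def grid_summary(nz, ny, nx, L, W):
    """Hypothetical helper: array shapes and spacings for an nz x ny x nx grid."""
    real_shape = (nz, ny, nx)
    # Assumed real-FFT layout: only nx // 2 + 1 non-negative zonal wavenumbers
    spectral_shape = (nz, ny, nx // 2 + 1)
    dx, dy = L / nx, W / ny  # assumed distance between neighboring grid points
    return real_shape, spectral_shape, (dx, dy)


real_shape, spectral_shape, (dx, dy) = grid_summary(nz=2, ny=64, nx=64, L=1e6, W=1e6)
qh = jnp.fft.rfftn(jnp.zeros(real_shape), axes=(-2, -1))
print(spectral_shape, qh.shape)  # (2, 64, 33) (2, 64, 33)
print(dx, dy)                    # 15625.0 15625.0
```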
- H
Domain length in the z direction.
In most cases this may actually be a JAX float scalar or tracer.
- Type:
- Hi
The length of each layer in the z direction.
This is a vector of length
nz whose entries sum to H.
- Type:
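For instance, with illustrative thicknesses for a two-layer setup (made-up values, not pyqg-jax defaults), H is the sum of the entries of Hi:

```python
import jax.numpy as jnp

# Illustrative layer thicknesses for nz = 2 (values chosen for the example)
Hi = jnp.array([500.0, 4500.0])
H = jnp.sum(Hi)
print(H)  # 5000.0
```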
- get_kappa(dtype=Precision.SINGLE)
Information on the wavenumber at each spectral grid point.
- Parameters:
precision (Precision, optional) – Precision for the wavenumber calculations.
- Returns:
A two-dimensional grid of wavenumber values at the specified precision.
These values have the shape of a spectral state (see
spectral_state_shape) without the leading nz dimension.
- Return type:
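A sketch of how such a wavenumber grid can be built with JAX, under three assumptions: the spectral layout is the real-FFT one (`ny` rows by `nx // 2 + 1` columns), the angular wavenumbers are `2π` times the FFT frequencies for spacings `L / nx` and `W / ny`, and the returned value is the magnitude `sqrt(k**2 + l**2)`. `sketch_kappa` and its `dtype` argument are hypothetical; consult the actual method for its exact conventions.

```python
import jax.numpy as jnp


def sketch_kappa(ny, nx, L, W, dtype=jnp.float32):
    """Hypothetical wavenumber-magnitude grid with shape (ny, nx // 2 + 1)."""
    kx = 2 * jnp.pi * jnp.fft.rfftfreq(nx, d=L / nx)  # zonal: non-negative only
    ky = 2 * jnp.pi * jnp.fft.fftfreq(ny, d=W / ny)   # meridional: both signs
    # Default "xy" indexing gives arrays of shape (len(ky), len(kx))
    kk, ll = jnp.meshgrid(kx, ky)
    return jnp.sqrt(kk**2 + ll**2).astype(dtype)


kappa = sketch_kappa(ny=64, nx=64, L=2 * jnp.pi, W=2 * jnp.pi)
print(kappa.shape)  # (64, 33)
```

With a domain of side `2π`, the wavenumbers are integers, so for example the entry at row 3, column 4 is 5.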