Parameterizations
Utilities for adding parameterizations to core models.
This subpackage includes some of the pre-implemented parameterizations from PyQG as well as utilities for defining custom parameterizations and applying them to a model.
Existing Parameterizations
Implementing Parameterizations
We include some utilities for applying custom parameterizations to a core model. User-defined parameterizations are implemented as a pair of functions: one computes the parameterized updates for the model, and the other initializes an auxiliary state.
A sample param_func could be:
def param_func(model_state, param_aux, model):
    # Access attributes and methods of the inner model
    updates = model.get_updates(model_state)
    # Compute updates and new auxiliary values
    new_updates = do_calculations(model_state, updates)
    new_param_aux = compute_new_aux(param_aux)
    # Return model state updates, and new aux values
    return new_updates, new_param_aux
and a sample init_param_aux_func:
def init_param_aux_func(model_state, model, *args, **kwargs):
    # Initialize the aux state based on inner model_state
    # and additional arguments
    return make_new_aux_state(model_state)
These two work together. The auxiliary state is an arbitrary object, but it must be a JAX PyTree. Simple choices are tuples of JAX values (arrays, keys, etc.) or immutable Python objects (str, bool, etc.). The auxiliary state can be None if no values are necessary.
Your param_func is responsible for updating the auxiliary state as
needed. ParameterizedModel will wrap the auxiliary state in a
NoStepValue so the
time-steppers will not manipulate it.
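As a rough illustration of a param_func that maintains its own auxiliary state, the sketch below keeps the previous model state as a one-step history. ToyModel, the array-valued state, and the 0.5 correction factor are hypothetical stand-ins chosen for this example; they are not part of pyqg_jax, and the functions are called directly rather than through ParameterizedModel.

```python
import jax.numpy as jnp


class ToyModel:
    """Hypothetical stand-in for a core model; not part of pyqg_jax."""

    def get_updates(self, model_state):
        # Placeholder tendency: linear decay toward zero
        return -0.1 * model_state


def init_param_aux_func(model_state, model):
    # Seed the one-step history with the initial state
    return model_state


def param_func(model_state, param_aux, model):
    previous_state = param_aux
    updates = model.get_updates(model_state)
    # Hypothetical correction based on the change since the last call
    correction = 0.5 * (model_state - previous_state)
    # Hand back the current state as the history for the next call
    return updates + correction, model_state


model = ToyModel()
state = jnp.ones((2, 4, 4))
aux = init_param_aux_func(state, model)
updates, aux = param_func(state, aux, model)
```

On the first call the history equals the current state, so the correction is zero and the result matches the toy model's own updates.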
The additional state is provided to allow propagating extra non-time-stepped values forward when stepping the model. Some possibilities:
- Stochastic parameterizations will need to include and split a key to use randomness.
- Stateful parameterizations could maintain a history of previous model states.
- Stateless, deterministic parameterizations can use None as their auxiliary state.
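The stochastic case might look something like the following sketch, in which the auxiliary state is a single PRNG key that is split on every call. ToyModel, the seed argument, and the 1e-3 noise amplitude are assumptions made for illustration and do not come from pyqg_jax.

```python
import jax
import jax.numpy as jnp


class ToyModel:
    """Hypothetical stand-in for a core model; not part of pyqg_jax."""

    def get_updates(self, model_state):
        # Placeholder tendency: linear decay toward zero
        return -0.1 * model_state


def init_param_aux_func(model_state, model, seed=0):
    # The auxiliary state is a single PRNG key
    return jax.random.key(seed)


def param_func(model_state, param_aux, model):
    # Split the stored key: one part is carried forward, the other is consumed now
    new_key, noise_key = jax.random.split(param_aux)
    updates = model.get_updates(model_state)
    noise = 1e-3 * jax.random.normal(noise_key, model_state.shape, model_state.dtype)
    return updates + noise, new_key


model = ToyModel()
state = jnp.ones((2, 4, 4))
aux = init_param_aux_func(state, model, seed=42)
updates_a, aux = param_func(state, aux, model)
updates_b, aux = param_func(state, aux, model)
```

Because the key returned by each call replaces the stored one, successive calls draw different noise while the whole sequence stays reproducible from the seed.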
Once you have implemented your parameterization, apply it to a base
model using ParameterizedModel, which can then be used with a
time stepper.
- class pyqg_jax.parameterizations.ParameterizedModel(model, param_func, init_param_aux_func=None)
A model wrapped in a user-specified parameterization.
- Parameters:
model – The inner core model to wrap in the parameterization.
param_func (function) – The function implementing the parameterization. Will be called by get_updates() to compute time-stepping updates.
init_param_aux_func (function, optional) – The function used to initialize the parameterization’s auxiliary state. Defaults to a function initializing the state to None.
- model
The inner model wrapped in the parameterization.
- param_func
The user-specified parameterization function. Takes arguments (model_state, param_aux, model).
- Type:
function
- init_param_aux_func
Function used to initialize the auxiliary state. Takes arguments (model_state, model).
- Type:
function
- get_full_state(state)
Expand a wrapped partial state into an unwrapped full state.
This function defers to model to compute the full state.
- Parameters:
state (ParameterizedModelState) – The wrapped, parameterized state to be expanded.
- Returns:
The expanded state. The real type depends on model, but is likely to be FullPseudoSpectralState.
- Return type:
- get_updates(state)
Get updates for time-stepping state.
state is a wrapped, partial model state. This function returns updates for time-stepping.
This function makes use of param_func, applying the parameterization to the updates.
- Parameters:
state (ParameterizedModelState) – The state which will be time stepped using the computed updates.
- Returns:
A new state object where each field corresponds to a time-stepping update to be applied.
- Return type:
ParameterizedModelState
Note
The object returned by this function has the same type as state, but contains updates. This is so that time-stepping can be done by mapping over the states and updates as JAX pytrees with the same structure.
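To make the note concrete, here is a minimal sketch of stepping by mapping over two pytrees with the same structure. It uses a forward Euler step for simplicity, which is an assumption of this example and not necessarily what the library's time steppers do; the dict-based state and the value of dt are likewise hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical state: any pytree works; here a dict with one array leaf
state = {"q": jnp.ones((2, 4, 4))}
# Updates share the structure of the state, but hold tendencies instead of values
updates = {"q": jnp.full((2, 4, 4), -0.5)}

dt = 0.1


def euler_step(state, updates, dt):
    # Pair each state leaf with the matching update leaf
    return jax.tree_util.tree_map(lambda s, u: s + dt * u, state, updates)


new_state = euler_step(state, updates, dt)
```

If the two pytrees did not share a structure, tree_map would raise an error, which is why the updates are returned in the same container type as the state.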
- postprocess_state(state)
Apply fixed filtering to state.
This function should be called once on each new state after each time step. SteppedModel handles this internally.
This function defers to model for the post-processing.
- Parameters:
state (ParameterizedModelState) – The wrapped state to be filtered.
- Returns:
The wrapped filtered state.
- Return type:
ParameterizedModelState
- create_initial_state(key, *args, **kwargs)
Create a new wrapped initial state with random initialization.
This function defers to model to initialize the inner state and makes use of init_param_aux_func to initialize the parameterization’s auxiliary state.
- Parameters:
key (jax.random.key) – The PRNG state used as the random key for initialization.
*args – Arbitrary additional arguments for init_param_aux_func
**kwargs – Arbitrary additional arguments for init_param_aux_func
- Returns:
The new wrapped state with random initialization.
- Return type:
ParameterizedModelState
- initialize_param_state(state, *args, **kwargs)
Wrap an existing state from model in a ParameterizedModelState.
This function takes an existing inner model state and wraps it so that it can be used with the parameterized model.
This function uses init_param_aux_func to initialize the parameterization’s auxiliary state.
- Parameters:
state – The inner model state to wrap. The type depends on model but is likely to be PseudoSpectralState.
*args – Arbitrary additional arguments for init_param_aux_func
**kwargs – Arbitrary additional arguments for init_param_aux_func
- Returns:
A wrapped copy of state.
- Return type:
ParameterizedModelState
- class pyqg_jax.parameterizations.ParameterizedModelState(*, model_state, param_aux)
Wrapped model state for parameterized models.
Warning
You should not construct this class yourself. Instead, you should obtain instances from ParameterizedModel.
- model_state
The inner model state. The actual type depends on the inner model, but this is likely to be FullPseudoSpectralState.
- Type:
- param_aux
The auxiliary state for the parameterization. This is an arbitrary object time-stepped by the parameterization itself. It will be wrapped in a NoStepValue to shield it from the time steppers.
- Type:
We also provide decorators which can simplify the process of implementing common parameterizations in terms of velocity or potential vorticity.
- @pyqg_jax.parameterizations.uv_parameterization
Decorator implementing parameterizations in terms of velocity.
The target function should take as its first three arguments (state, param_aux, model) as with any other parameterization function. Additional arguments will be passed through unmodified.
This function should then return two values: (du, dv), param_aux. These values will then be added to the model’s original update value to form the parameterized update.
The wrapped function is suitable for use with ParameterizedModel.
- @pyqg_jax.parameterizations.q_parameterization
Decorator implementing parameterizations in terms of potential vorticity.
The target function should take as its first three arguments (state, param_aux, model) as with any other parameterization function. Additional arguments will be passed through unmodified.
This function should then return two values: dq, param_aux. These values will then be added to the model’s original update value to form the parameterized update.
The wrapped function is suitable for use with ParameterizedModel.
See also: pyqg.parameterizations.QParameterization