Parameterizations

Utilities for adding parameterizations to core models.

This subpackage includes some of the pre-implemented parameterizations from PyQG as well as utilities for defining custom parameterizations and applying them to a model.

Implementing Parameterizations

We include some utilities for applying custom parameterizations to a core model. User-defined parameterizations are implemented as a pair of functions: one computes the parameterized updates for the model, and the other initializes an auxiliary state.

A sample param_func could be:

def param_func(model_state, param_aux, model):
    # Access attributes and methods of the inner model
    updates = model.get_updates(model_state)
    # Compute updates and new auxiliary values (do_calculations and
    # compute_new_aux are placeholders for your own code)
    new_updates = do_calculations(model_state, updates)
    new_param_aux = compute_new_aux(param_aux)
    # Return model state updates, and new aux values
    return new_updates, new_param_aux

and a sample init_param_aux_func:

def init_param_aux_func(model_state, model, *args, **kwargs):
    # Initialize the aux state based on the inner model_state and any
    # additional arguments (make_new_aux_state is a placeholder)
    return make_new_aux_state(model_state)
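The sketch below is a concrete version of the two samples above that runs without pyqg_jax. It uses a hypothetical ToyModel, which exposes only get_updates, and plain dicts of floats in place of real model states. Neither is part of pyqg_jax. The parameterization adds an illustrative linear damping term. Because it is stateless, it uses None as its auxiliary state.

```python
class ToyModel:
    """Hypothetical stand-in for a core model; only get_updates is provided."""

    def get_updates(self, model_state):
        # Toy tendency dq/dt = -q
        return {"q": -model_state["q"]}


def param_func(model_state, param_aux, model):
    # Start from the inner model's own updates
    updates = model.get_updates(model_state)
    # Add an illustrative linear damping term to the q update
    new_updates = {"q": updates["q"] - 0.1 * model_state["q"]}
    # Stateless: pass the (None) auxiliary state through unchanged
    return new_updates, param_aux


def init_param_aux_func(model_state, model, *args, **kwargs):
    # No auxiliary values are needed for this parameterization
    return None


model = ToyModel()
state = {"q": 2.0}
aux = init_param_aux_func(state, model)
updates, aux = param_func(state, aux, model)
```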

These two functions work together. init_param_aux_func creates the auxiliary state, and param_func receives it and returns its updated value. The auxiliary state is an arbitrary object, but it must be a JAX PyTree. Simple choices are tuples of JAX values (arrays, keys, etc.) or immutable Python objects (str, bool, etc.). The auxiliary state can be None if no values are needed.

Your param_func is responsible for updating the auxiliary state as needed. ParameterizedModel wraps the auxiliary state in a NoStepValue so that the time steppers do not modify it.
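The following minimal sketch shows what it means for param_func to own the auxiliary state. It assumes a hypothetical ToyModel and uses a hand-written forward-Euler loop as a stand-in for a real time stepper. The loop integrates only the model state. The integer call counter in the auxiliary state advances solely because param_func returns an incremented value.

```python
class ToyModel:
    """Hypothetical stand-in for a core model."""

    def get_updates(self, model_state):
        return {"q": -model_state["q"]}


def init_param_aux_func(model_state, model):
    # Auxiliary state: number of times the parameterization has run
    return 0


def param_func(model_state, param_aux, model):
    updates = model.get_updates(model_state)
    # Nothing else advances the counter, so return the new value ourselves
    return updates, param_aux + 1


model = ToyModel()
state = {"q": 1.0}
aux = init_param_aux_func(state, model)
dt = 0.1
for _ in range(3):
    updates, aux = param_func(state, aux, model)
    # Forward-Euler step on the model state only; aux is carried forward untouched
    state = {name: state[name] + dt * updates[name] for name in state}
```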

The auxiliary state makes it possible to carry extra, non-time-stepped values forward when stepping the model. Some possibilities:

  • Stochastic parameterizations need to include a PRNG key in the auxiliary state and split it to use randomness.

  • Stateful parameterizations could maintain a history of previous model states.

  • Stateless, deterministic parameterizations can use None as their auxiliary state.
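The parameterization below is a sketch of the stateful case. It keeps the last few values of a toy field q as its auxiliary state and relaxes q toward their mean. ToyModel, HISTORY_LEN, and the relaxation term are illustrative assumptions. The history is a fixed-length tuple, so the PyTree structure of the auxiliary state does not change between steps. That matters if stepping runs under a JAX transformation such as jax.lax.scan, which requires a loop carry of constant structure. A stochastic parameterization would follow the same pattern with a PRNG key as the auxiliary state: split it with jax.random.split on each call, use one half, and return the other.

```python
HISTORY_LEN = 3  # hypothetical window size


class ToyModel:
    """Hypothetical stand-in for a core model."""

    def get_updates(self, model_state):
        return {"q": -model_state["q"]}


def init_param_aux_func(model_state, model):
    # Pre-fill the history so its length (and PyTree structure) never changes
    return (model_state["q"],) * HISTORY_LEN


def param_func(model_state, param_aux, model):
    updates = model.get_updates(model_state)
    # Drop the oldest entry and append the current q: a fixed-size sliding window
    history = param_aux[1:] + (model_state["q"],)
    # Illustrative memory term: relax q toward the mean of the recent history
    mean_q = sum(history) / len(history)
    new_updates = {"q": updates["q"] + 0.5 * (mean_q - model_state["q"])}
    return new_updates, history


model = ToyModel()
state = {"q": 1.0}
aux = init_param_aux_func(state, model)
for _ in range(2):
    updates, aux = param_func(state, aux, model)
    state = {"q": state["q"] + 0.1 * updates["q"]}
```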

Once you have implemented your parameterization, apply it to a base model using ParameterizedModel, which can then be used with a time stepper.

class pyqg_jax.parameterizations.ParameterizedModel(model, param_func, init_param_aux_func=None)

A model wrapped in a user-specified parameterization.

Parameters:
  • model – The inner core model to wrap in the parameterization.

  • param_func (function) – The function implementing the parameterization. Will be called by get_updates() to compute time-stepping updates.

  • init_param_aux_func (function, optional) – The function used to initialize the parameterization’s auxiliary state. Defaults to a function initializing the state to None.

model

The inner model wrapped in the parameterization.

param_func

The user-specified parameterization function. Takes arguments (model_state, param_aux, model).

Type:

function

init_param_aux_func

Function used to initialize the auxiliary state. Takes arguments (model_state, model), followed by any additional arguments passed to create_initial_state() or initialize_param_state().

Type:

function
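The following sketch shows one way a wrapper with these three attributes could fit together. ToyParameterizedModel and ToyWrappedState are hypothetical stand-ins written for illustration, not pyqg_jax's implementation. Among other differences, they skip the NoStepValue wrapping. The default initializer mirrors the documented behavior of returning None.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class ToyWrappedState:
    """Hypothetical stand-in for ParameterizedModelState."""

    model_state: Any
    param_aux: Any


def default_init_param_aux(model_state, model, *args, **kwargs):
    # Mirrors the documented default: the auxiliary state is None
    return None


class ToyParameterizedModel:
    """Hypothetical sketch of the wrapping pattern, not pyqg_jax's implementation."""

    def __init__(self, model, param_func, init_param_aux_func: Optional[Callable] = None):
        self.model = model
        self.param_func = param_func
        if init_param_aux_func is None:
            init_param_aux_func = default_init_param_aux
        self.init_param_aux_func = init_param_aux_func

    def initialize_param_state(self, state, *args, **kwargs):
        # Extra arguments are forwarded to init_param_aux_func
        aux = self.init_param_aux_func(state, self.model, *args, **kwargs)
        return ToyWrappedState(model_state=state, param_aux=aux)

    def get_updates(self, state):
        # Call the user's function as (model_state, param_aux, model)
        updates, new_aux = self.param_func(state.model_state, state.param_aux, self.model)
        # Package the updates in the same wrapper type as the state
        return ToyWrappedState(model_state=updates, param_aux=new_aux)


class ToyModel:
    """Hypothetical stand-in for a core model."""

    def get_updates(self, model_state):
        return {"q": -model_state["q"]}


def param_func(model_state, param_aux, model):
    # Pass-through parameterization: the model's updates are left unchanged
    return model.get_updates(model_state), param_aux


wrapped_model = ToyParameterizedModel(ToyModel(), param_func)
wrapped_state = wrapped_model.initialize_param_state({"q": 1.0})
wrapped_updates = wrapped_model.get_updates(wrapped_state)
```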

get_full_state(state)

Expand a wrapped partial state into an unwrapped full state.

This function defers to model to compute the full state.

Parameters:

state (ParameterizedModelState) – The wrapped, parameterized state to be expanded.

Returns:

The expanded state. The actual type depends on model, but is likely to be FullPseudoSpectralState.

Return type:

FullPseudoSpectralState

get_updates(state)

Get updates for time-stepping state.

state is a wrapped, partial model state. This function returns updates for time-stepping.

This function calls param_func to apply the parameterization to the updates.

Parameters:

state (ParameterizedModelState) – The state which will be time stepped using the computed updates.

Returns:

A new state object where each field corresponds to a time-stepping update to be applied.

Return type:

ParameterizedModelState

Note

The object returned by this function has the same type as state, but contains updates. This lets the time stepping be done by mapping over the states and updates as JAX PyTrees with the same structure.
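The forward-Euler step below shows why matching structure matters: it maps over a state and its updates leaf by leaf. In JAX code this would normally be done with jax.tree_util.tree_map. Here, tree_map is a hypothetical pure-Python stand-in that only handles plain dicts, tuples, and lists, and the dict-based state is a toy.

```python
def tree_map(f, tree, *rest):
    """Minimal stand-in for jax.tree_util.tree_map over plain dicts, tuples, and lists."""
    if type(tree) is dict:
        return {key: tree_map(f, tree[key], *(r[key] for r in rest)) for key in tree}
    if type(tree) in (tuple, list):
        return type(tree)(tree_map(f, *children) for children in zip(tree, *rest))
    # Anything else is a leaf: apply f to the matching leaves of every tree
    return f(tree, *rest)


def euler_step(state, updates, dt):
    # Works only because updates has exactly the same structure as state
    return tree_map(lambda s, u: s + dt * u, state, updates)


state = {"q": (1.0, 2.0), "t": 0.0}
updates = {"q": (-1.0, -2.0), "t": 1.0}
new_state = euler_step(state, updates, 0.5)
```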

postprocess_state(state)

Apply fixed filtering to state.

This function should be called once on each new state after each time step.

SteppedModel handles this internally.

This function defers to model for the post-processing.

Parameters:

state (ParameterizedModelState) – The wrapped state to be filtered.

Returns:

The wrapped filtered state.

Return type:

ParameterizedModelState
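Because SteppedModel handles this internally, you normally never call postprocess_state yourself. The hand-rolled loop below is only a sketch of where the call belongs: once per newly produced state, immediately after the step. ToyModel and its constant filter_factor are hypothetical stand-ins for a real model's filtering.

```python
class ToyModel:
    """Hypothetical stand-in with a fixed 'filter' applied after each step."""

    filter_factor = 0.99  # hypothetical constant filter

    def get_updates(self, state):
        return {"q": -state["q"]}

    def postprocess_state(self, state):
        return {"q": self.filter_factor * state["q"]}


def run(model, state, dt, num_steps):
    for _ in range(num_steps):
        updates = model.get_updates(state)
        stepped = {name: state[name] + dt * updates[name] for name in state}
        # Post-process exactly once per newly produced state
        state = model.postprocess_state(stepped)
    return state


final = run(ToyModel(), {"q": 1.0}, dt=0.1, num_steps=2)
```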

create_initial_state(key, *args, **kwargs)

Create a new wrapped initial state with random initialization.

This function defers to model to initialize the inner state and makes use of init_param_aux_func to initialize the parameterization’s auxiliary state.

Parameters:
  • key (jax.random.key) – The PRNG state used as the random key for initialization.

  • *args – Arbitrary additional arguments for init_param_aux_func.

  • **kwargs – Arbitrary additional arguments for init_param_aux_func.

Returns:

The new wrapped state with random initialization.

Return type:

ParameterizedModelState
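The sketch below shows how the extra arguments are routed. The key goes to the inner model, while *args and **kwargs reach only init_param_aux_func. The example is written as a hypothetical free function over a toy model. A Python int seed replaces the jax.random key purely so that it runs without JAX.

```python
import random


class ToyModel:
    """Hypothetical stand-in; a Python int seed replaces a jax.random key."""

    def create_initial_state(self, key):
        rng = random.Random(key)
        return {"q": rng.uniform(-1.0, 1.0)}


def init_param_aux_func(model_state, model, history_len, fill=0.0):
    # Extra arguments arrive here, not at the inner model
    return (fill,) * history_len


def create_initial_state(model, init_aux, key, *args, **kwargs):
    # Hypothetical free-function version of the documented method
    inner = model.create_initial_state(key)
    aux = init_aux(inner, model, *args, **kwargs)
    return {"model_state": inner, "param_aux": aux}


wrapped = create_initial_state(ToyModel(), init_param_aux_func, 0, 4, fill=1.0)
```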

initialize_param_state(state, *args, **kwargs)

Wrap an existing state from model in a ParameterizedModelState.

This function takes an existing inner model state and wraps it so that it can be used with the parameterized model.

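This function uses init_param_aux_func to initialize the parameterization’s auxiliary state.

Parameters:
  • state – The existing inner model state to wrap.

  • *args – Arbitrary additional arguments for init_param_aux_func.

  • **kwargs – Arbitrary additional arguments for init_param_aux_func.

Returns: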

A wrapped copy of state.

Return type:

ParameterizedModelState
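Here is a minimal sketch of the wrapping step. It assumes a toy dict state and uses a hypothetical ToyWrappedState in place of ParameterizedModelState. A typical (assumed) use is wrapping a state that was produced elsewhere, for example by the unparameterized model. The inner state is reused unchanged; only the auxiliary state is computed.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyWrappedState:
    """Hypothetical stand-in for ParameterizedModelState."""

    model_state: Any
    param_aux: Any


def init_param_aux_func(model_state, model, scale=1.0):
    # Derive the auxiliary state from the existing inner state (model is unused here)
    return scale * model_state["q"]


def initialize_param_state(model, init_aux, state, *args, **kwargs):
    # Hypothetical free-function version: the inner state is reused as-is
    aux = init_aux(state, model, *args, **kwargs)
    return ToyWrappedState(model_state=state, param_aux=aux)


spun_up = {"q": 0.25}  # e.g. a state produced by running the unparameterized model
wrapped = initialize_param_state(None, init_param_aux_func, spun_up, scale=2.0)
```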

class pyqg_jax.parameterizations.ParameterizedModelState(*, model_state, param_aux)

Wrapped model state for parameterized models.

Warning

You should not construct this class yourself. Instead, you should obtain instances from ParameterizedModel.

model_state

The inner model state. The actual type depends on the inner model, but is likely to be PseudoSpectralState.

Type:

PseudoSpectralState

param_aux

The auxiliary state for the parameterization. This is an arbitrary object that the parameterization updates itself; it is wrapped in a NoStepValue to shield it from the time steppers.

Type:

NoStepValue
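One plausible way for a time stepper to honor such a marker is to copy the new value from the updates instead of integrating it. The sketch below illustrates that idea with a hypothetical ToyNoStep class and a toy Euler step. It is an assumption about the mechanism, not the actual implementation of NoStepValue.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyNoStep:
    """Hypothetical marker playing the role of NoStepValue."""

    value: Any


def euler_step(state, updates, dt):
    new_state = {}
    for name, current in state.items():
        update = updates[name]
        if isinstance(current, ToyNoStep):
            # Not time-stepped: take the value produced by the parameterization
            new_state[name] = update
        else:
            new_state[name] = current + dt * update
    return new_state


state = {"q": 1.0, "param_aux": ToyNoStep(0)}
updates = {"q": -1.0, "param_aux": ToyNoStep(1)}
new_state = euler_step(state, updates, dt=0.1)
```

Had the counter been integrated like q, it would have become 0 + 0.1 * 1 = 0.1 instead of the intended 1.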

We also provide decorators which can simplify the process of implementing common parameterizations in terms of velocity or potential vorticity.

@pyqg_jax.parameterizations.uv_parameterization

Decorator implementing parameterizations in terms of velocity.

The target function should take as its first three arguments (state, param_aux, model) as with any other parameterization function. Additional arguments will be passed through unmodified.

The target function should return two values: (du, dv), param_aux, where param_aux is the new auxiliary state. The velocity updates du and dv will then be added to the model’s original updates to form the parameterized update.

The wrapped function is suitable for use with ParameterizedModel.

See also: pyqg.parameterizations.UVParameterization
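The sketch below illustrates the decorator contract using a hypothetical toy_uv_parameterization and a toy model whose updates store u and v directly. The real decorator works with the model's own update fields, so its internals will differ. Only the calling convention shown here comes from the description above: the ((du, dv), param_aux) return shape and the pass-through of extra arguments.

```python
import functools


def toy_uv_parameterization(func):
    """Hypothetical stand-in for uv_parameterization, for illustration only."""

    @functools.wraps(func)
    def wrapper(state, param_aux, model, *args, **kwargs):
        # The user function returns velocity updates plus the new aux state
        (du, dv), new_aux = func(state, param_aux, model, *args, **kwargs)
        updates = model.get_updates(state)
        # Add the velocity terms to the toy model's own updates
        combined = {"u": updates["u"] + du, "v": updates["v"] + dv}
        return combined, new_aux

    return wrapper


class ToyModel:
    """Hypothetical model whose updates are stored directly as u and v."""

    def get_updates(self, state):
        return {"u": -state["u"], "v": -state["v"]}


@toy_uv_parameterization
def drag(state, param_aux, model, coeff=0.5):
    # Extra keyword arguments such as coeff pass straight through the decorator
    return (-coeff * state["u"], -coeff * state["v"]), param_aux


updates, aux = drag({"u": 2.0, "v": -4.0}, None, ToyModel(), coeff=0.25)
```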

@pyqg_jax.parameterizations.q_parameterization

Decorator implementing parameterizations in terms of potential vorticity.

The target function should take as its first three arguments (state, param_aux, model) as with any other parameterization function. Additional arguments will be passed through unmodified.

The target function should return two values: dq, param_aux, where param_aux is the new auxiliary state. The potential vorticity update dq will then be added to the model’s original updates to form the parameterized update.

The wrapped function is suitable for use with ParameterizedModel.

See also: pyqg.parameterizations.QParameterization
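The same contract for potential vorticity is sketched below with a hypothetical toy_q_parameterization and a toy model. The user function returns a damping term dq and a call counter as its new auxiliary state. Here, rate is an extra positional argument that is passed through unmodified.

```python
import functools


def toy_q_parameterization(func):
    """Hypothetical stand-in for q_parameterization, for illustration only."""

    @functools.wraps(func)
    def wrapper(state, param_aux, model, *args, **kwargs):
        dq, new_aux = func(state, param_aux, model, *args, **kwargs)
        updates = model.get_updates(state)
        # Add the PV term to the toy model's own q update
        return {"q": updates["q"] + dq}, new_aux

    return wrapper


class ToyModel:
    """Hypothetical stand-in for a core model."""

    def get_updates(self, state):
        return {"q": -state["q"]}


@toy_q_parameterization
def counted_damping(state, param_aux, model, rate):
    # rate is passed through unmodified; param_aux counts calls
    return -rate * state["q"], param_aux + 1


updates, aux = counted_damping({"q": 2.0}, 0, ToyModel(), 0.5)
```

ParameterizedModel calls the parameterization with only (model_state, param_aux, model). Any extra arguments therefore presumably need to be bound in advance, for example with functools.partial(counted_damping, rate=0.5).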