Time-Steppers

Time-stepping schemes and utilities for using them to update model states.

To step a model through time, combine it with a time-stepper using SteppedModel.

The time-steppers included in this package are:

  • AB3Stepper: Third-order Adams-Bashforth stepper (the same method used in PyQG)

  • EulerStepper: Forward Euler first-order stepper

Combined Models

These utilities add time-stepping to base models.

class pyqg_jax.steppers.SteppedModel(model, stepper)

Combine an inner model with a time stepper.

This class simplifies the process of stepping a base model through time by handling the interactions between the model and time stepper.

Parameters:
  • model – The inner model to step through time.

  • stepper – The time stepper applying the updates to the model each step.

model

The inner model being stepped.

stepper

The time stepper used to apply the updates.
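The division of labor between the two attributes can be sketched with toy classes. Everything below is hypothetical: ToyDecayModel, ToyEulerStepper, ToySteppedModel, and the get_updates/postprocess_state method names are illustrative assumptions rather than the pyqg_jax API, and the real classes do considerably more. The sketch only shows the order of operations: the model computes updates, the stepper applies them, and post-processing runs afterwards.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyStepperState(NamedTuple):
    # Hypothetical stand-in for StepperState: inner state, model time, step count.
    state: jnp.ndarray
    t: jnp.ndarray
    tc: jnp.ndarray


class ToyDecayModel:
    """Hypothetical model with tendency dq/dt = -rate * q."""

    def __init__(self, rate):
        self.rate = rate

    def get_updates(self, q):
        # Time tendency of the inner state (assumed method name).
        return -self.rate * q

    def postprocess_state(self, q):
        # Placeholder for filtering: zero out negligibly small values.
        return jnp.where(jnp.abs(q) < 1e-12, 0.0, q)


class ToyEulerStepper:
    """Hypothetical forward Euler stepper."""

    def __init__(self, dt):
        self.dt = dt

    def initialize_stepper_state(self, q):
        # Wrap the inner state, starting the clock at zero.
        return ToyStepperState(state=q, t=jnp.float32(0.0), tc=jnp.uint32(0))

    def apply_updates(self, stepper_state, updates):
        # Advance the wrapped state; no post-processing happens here.
        return ToyStepperState(
            state=stepper_state.state + self.dt * updates,
            t=stepper_state.t + self.dt,
            tc=stepper_state.tc + 1,
        )


class ToySteppedModel:
    """Combines a model and a stepper, mirroring the roles described above."""

    def __init__(self, model, stepper):
        self.model = model
        self.stepper = stepper

    def initialize_stepper_state(self, q):
        # Wrapping is delegated to the stepper.
        return self.stepper.initialize_stepper_state(q)

    def step_model(self, stepper_state):
        # 1. The model computes updates from the inner state.
        updates = self.model.get_updates(stepper_state.state)
        # 2. The stepper applies them to produce the next wrapped state.
        new_state = self.stepper.apply_updates(stepper_state, updates)
        # 3. Post-processing is applied to the new inner state last.
        return new_state._replace(state=self.model.postprocess_state(new_state.state))
```

Keeping post-processing out of the stepper lets any stepper be paired with any model, since only the combined object needs to know about both.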

create_initial_state(key, *args, **kwargs)

Create a new wrapped initial state with random initialization.

This function defers to model to initialize the inner state, passing it any additional arguments. It then wraps the result by calling initialize_stepper_state on stepper.

Parameters:
  • key (jax.random.key) – The PRNG state used as the random key for initialization.

  • *args – Arbitrary additional arguments for the model’s initialization function.

  • **kwargs – Arbitrary additional arguments for the model’s initialization function.

Returns:

The new wrapped state with random initialization.

Return type:

StepperState
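The two-stage flow (model initializes from a PRNG key, then the result is wrapped) can be sketched as follows. The functions toy_model_initial_state and toy_create_initial_state and the ToyStepperState class are hypothetical stand-ins, and the random perturbation is an arbitrary choice; only the pattern of forwarding extra arguments to the model and then wrapping with t and tc at zero follows the description above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyStepperState(NamedTuple):
    # Hypothetical wrapper: inner state, model time, step counter.
    state: jax.Array
    t: jax.Array
    tc: jax.Array


def toy_model_initial_state(key, shape=(2, 8)):
    # Hypothetical model initializer: a small random perturbation.
    return 1e-3 * jax.random.normal(key, shape)


def toy_create_initial_state(key, *args, **kwargs):
    # Defer to the model for the inner state, forwarding any extra arguments...
    inner = toy_model_initial_state(key, *args, **kwargs)
    # ...then wrap it so it is ready for time stepping.
    return ToyStepperState(state=inner, t=jnp.float32(0.0), tc=jnp.uint32(0))
```

With the documented API, the corresponding call is stepped_model.create_initial_state(jax.random.key(0)), where stepped_model is a SteppedModel instance. Reusing the same key reproduces the same initial state.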

initialize_stepper_state(state, /)

Wrap an existing state from model in a StepperState, preparing it for time stepping.

This function takes an existing inner model state and wraps it so that it can be stepped through time by stepper.

This function defers to stepper to perform the wrapping.

Parameters:

state – The inner model state to wrap. The type depends on model but is likely to be PseudoSpectralState or ParameterizedModelState.

Returns:

A wrapped copy of state.

Return type:

StepperState

step_model(stepper_state, /)

Update a state by computing the next time step.

This method handles the interaction between model and stepper, including post-processing/filtering.

To take multiple steps over time, combine this method with jax.lax.scan().

Parameters:

stepper_state (StepperState) – The wrapped state to step forward in time.

Returns:

The updated wrapped state, a new object.

Return type:

StepperState
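Combining a single-step function with jax.lax.scan() to produce a trajectory can be sketched like this. The toy_step_model function is a hypothetical stand-in (forward Euler on dq/dt = -q with dt = 0.1, carrying a plain (q, t, tc) tuple); with the real class, loop_fn would call stepped_model.step_model(carry) on a StepperState instead.

```python
import jax
import jax.numpy as jnp


def toy_step_model(carry):
    # Hypothetical stand-in for SteppedModel.step_model.
    # carry = (q, t, tc); forward Euler on dq/dt = -q with dt = 0.1.
    q, t, tc = carry
    dt = 0.1
    return (q + dt * (-q), t + dt, tc + 1)


def roll_out(initial, num_steps):
    def loop_fn(carry, _x):
        next_carry = toy_step_model(carry)
        # The carry feeds the next step; the second output is stacked
        # by scan into a trajectory of inner states.
        return next_carry, next_carry[0]

    final, trajectory = jax.lax.scan(loop_fn, initial, None, length=num_steps)
    return final, trajectory
```

Because scan needs a static length, num_steps must be a Python int; if roll_out is wrapped in jax.jit, mark num_steps as static (for example with static_argnames).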

get_full_state(stepper_state, /)

Expand a wrapped partial state into an unwrapped full state.

This function defers to model to compute the full state.

Parameters:

stepper_state (StepperState) – The wrapped state to be expanded.

Returns:

The expanded state. The actual type depends on model, but is likely to be FullPseudoSpectralState.

Return type:

FullPseudoSpectralState
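The idea of a compact partial state that is expanded on demand can be sketched in one dimension. ToyPartialState, ToyFullState, ToyStepperState, and toy_get_full_state are hypothetical; the assumption is only that the stepped state stores spectral coefficients while derived fields (here the physical-space field q) are recomputed when requested, with the wrapper removed in the process.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyPartialState(NamedTuple):
    # Hypothetical partial state: only spectral coefficients are stored.
    qh: jnp.ndarray


class ToyStepperState(NamedTuple):
    # Hypothetical wrapper around the partial state.
    state: ToyPartialState
    t: float
    tc: int


class ToyFullState(NamedTuple):
    # Hypothetical full state: stored field plus a derived quantity.
    qh: jnp.ndarray
    q: jnp.ndarray


def toy_get_full_state(stepper_state, n):
    # Unwrap, then recompute derived fields instead of carrying them
    # through every time step.
    partial = stepper_state.state
    q = jnp.fft.irfft(partial.qh, n=n)
    return ToyFullState(qh=partial.qh, q=q)
```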

class pyqg_jax.steppers.StepperState(state, t, tc)

Model state wrapped for time-stepping

Warning

You should not construct this class yourself. Instead, you should obtain instances from your chosen time stepper, or SteppedModel.

state

The inner state from the model being stepped forward. The actual type of state depends on the model being stepped.

Type:

PseudoSpectralState or ParameterizedModelState

t

The current model time

Type:

jax.numpy.float32

tc

The current model time step count (the number of steps taken so far)

Type:

jax.numpy.uint32

update(**kwargs)

Replace values stored in this state.

This function produces a new state object, containing the replacement values.

The keyword arguments may be any of state, t, or tc.

The object this method is called on is not modified.

Parameters:

**kwargs – Replacement values for any of state, t, or tc.

Returns:

A copy of this object with the specified values replaced.

Return type:

StepperState
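The non-mutating update() behavior can be sketched with an immutable NamedTuple. ToyStepperState is a hypothetical stand-in, not the library class; it relies on NamedTuple._replace, which builds a new tuple and raises ValueError for field names other than state, t, and tc.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp


class ToyStepperState(NamedTuple):
    """Hypothetical immutable stand-in for StepperState."""

    state: Any
    t: jnp.ndarray
    tc: jnp.ndarray

    def update(self, **kwargs):
        # _replace returns a new object and leaves self untouched;
        # unknown field names raise ValueError.
        return self._replace(**kwargs)
```

Fields that are not replaced are shared with the original object rather than copied, which is cheap because JAX arrays are immutable.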

Time Stepping Schemes

Implemented time-stepping schemes: AB3Stepper and EulerStepper.

class pyqg_jax.steppers.AB3Stepper(dt)

Third-order Adams-Bashforth stepper.

This is the same time stepping scheme as used in PyQG.

This time-stepper bootstraps using lower-order Adams-Bashforth schemes for the first two steps.
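For reference, the textbook Adams-Bashforth updates with this bootstrapping are written below, with f^{n} the model tendency evaluated at step n and \Delta t = dt. This is the standard form of the scheme; the library's internal bookkeeping may differ, and post-processing is applied separately.

```latex
\begin{align*}
q^{1}   &= q^{0} + \Delta t \, f^{0} \\
q^{2}   &= q^{1} + \Delta t \left( \tfrac{3}{2} f^{1} - \tfrac{1}{2} f^{0} \right) \\
q^{n+1} &= q^{n} + \tfrac{\Delta t}{12} \left( 23 f^{n} - 16 f^{n-1} + 5 f^{n-2} \right), \quad n \ge 2
\end{align*}
```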

Parameters:

dt (float) – Numerical time step

dt

Numerical time step

Type:

float

initialize_stepper_state(state)

Wrap an existing state from a model in a StepperState to prepare it for time stepping.

This initializes a new StepperState starting at time 0.

Parameters:

state – The model state to wrap.

Returns:

The wrapped state. Note this will be a subclass of StepperState appropriate for this time stepper.

Return type:

StepperState

apply_updates(stepper_state, updates)

Apply updates to the existing stepper_state producing the next step in time.

updates should be provided by the model that produced StepperState.state.

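Parameters:
  • stepper_state (StepperState) – The wrapped state to advance.

  • updates – The updates computed by the model for stepper_state.state.

Returns: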

The updated, wrapped state at the next time step.

Return type:

StepperState

Note

This method does not apply post-processing to the updated state.
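A minimal sketch of an AB3 stepper with the Euler and AB2 bootstrap follows. ToyAB3State and ToyAB3Stepper are hypothetical; storing the two previous tendencies in the wrapper is an assumption that illustrates one plausible reason a stepper-specific StepperState subclass is needed. Coefficients are selected with jnp.where on the step counter so the function stays traceable under jax.jit or jax.lax.scan.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyAB3State(NamedTuple):
    # Hypothetical AB3 wrapper: AB3 needs the two previous tendencies.
    state: jax.Array
    t: jax.Array
    tc: jax.Array
    prev1: jax.Array  # tendency from step n-1
    prev2: jax.Array  # tendency from step n-2


class ToyAB3Stepper:
    """Hypothetical third-order Adams-Bashforth stepper with AB1/AB2 bootstrapping."""

    def __init__(self, dt):
        self.dt = dt

    def initialize_stepper_state(self, q):
        zeros = jnp.zeros_like(q)
        return ToyAB3State(q, jnp.float32(0.0), jnp.uint32(0), zeros, zeros)

    def apply_updates(self, s, updates):
        # Step 0 -> forward Euler, step 1 -> AB2, afterwards -> AB3.
        c0 = jnp.where(s.tc == 0, 1.0, jnp.where(s.tc == 1, 1.5, 23.0 / 12.0))
        c1 = jnp.where(s.tc == 0, 0.0, jnp.where(s.tc == 1, -0.5, -16.0 / 12.0))
        c2 = jnp.where(s.tc < 2, 0.0, 5.0 / 12.0)
        new_q = s.state + self.dt * (c0 * updates + c1 * s.prev1 + c2 * s.prev2)
        # Shift the tendency history; no post-processing is applied here.
        return ToyAB3State(
            state=new_q,
            t=s.t + self.dt,
            tc=s.tc + 1,
            prev1=updates,
            prev2=s.prev1,
        )
```

In each phase the coefficients sum to 1, so a constant tendency advances the state by exactly dt times that tendency per step, which makes a convenient sanity check.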

class pyqg_jax.steppers.EulerStepper(dt)

Forward Euler (first-order) stepper.

New in version 0.8.0.

Parameters:

dt (float) – Numerical time step

dt

Numerical time step

Type:

float

apply_updates(stepper_state, updates)

Apply updates to the existing stepper_state producing the next step in time.

updates should be provided by the model that produced StepperState.state.

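Parameters:
  • stepper_state (StepperState) – The wrapped state to advance.

  • updates – The updates computed by the model for stepper_state.state.

Returns: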

The updated, wrapped state at the next time step.

Return type:

StepperState

Note

This method does not apply post-processing to the updated state.

initialize_stepper_state(state)

Wrap an existing state from a model in a StepperState to prepare it for time stepping.

This initializes a new StepperState starting at time 0.

Parameters:

state – The model state to wrap.

Returns:

The wrapped state. Note this will be a subclass of StepperState appropriate for this time stepper.

Return type:

StepperState
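Forward Euler applies q_{n+1} = q_n + dt * f(q_n). The hypothetical sketch below integrates dq/dt = -q to t = 1 and compares against exp(-1), showing the first-order behavior: halving dt roughly halves the error. The functions toy_euler_integrate and euler_error are illustrative and are not part of pyqg_jax.

```python
import math

import jax.numpy as jnp


def toy_euler_integrate(q0, rate, dt, num_steps):
    # Hypothetical forward Euler loop for dq/dt = -rate * q:
    # q_{n+1} = q_n + dt * f(q_n), with f(q) = -rate * q.
    q = jnp.asarray(q0, dtype=jnp.float32)
    for _ in range(num_steps):
        q = q + dt * (-rate * q)
    return q


def euler_error(dt, t_end=1.0):
    # Absolute error at t_end against the exact solution exp(-t_end).
    n = round(t_end / dt)
    approx = toy_euler_integrate(1.0, 1.0, dt, n)
    return abs(float(approx) - math.exp(-t_end))
```

The first-order error is the main trade-off against AB3Stepper; in exchange, Euler needs no tendency history and no bootstrapping.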

State Manipulation

NoStepValue makes it possible to shield values from time-stepping so they can be updated manually.

class pyqg_jax.steppers.NoStepValue(value)

Shields contents from the provided time-steppers.

When a time-stepper encounters a value wrapped in this class, it will skip its normal stepping computations and directly use the value from the updates. This allows a user to manually update an auxiliary value outside the normal time-stepping.

For example, jax.random.key() values should not be time-stepped normally. Wrapping them in this class and manually updating them can accomplish this.

This class is used as part of ParameterizedModelState.

Parameters:

value (object) – The inner value to wrap. This can be an arbitrary JAX PyTree.

value

The internal, wrapped value

Type:

pyqg_jax.steppers.P
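The shielding mechanism can be sketched with a small pytree wrapper and a forward Euler update. ToyNoStepValue and toy_euler_apply_updates are hypothetical and need not match the library's internals. The sketch passes is_leaf to jax.tree_util.tree_map so that shielded entries are handled as a unit: their new value is copied from the updates, while ordinary leaves are integrated. This also avoids attempting arithmetic on a PRNG key.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyNoStepValue:
    """Hypothetical stand-in for NoStepValue: marks a value the stepper must not integrate."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def _is_shielded(x):
    return isinstance(x, ToyNoStepValue)


def toy_euler_apply_updates(state, updates, dt):
    def step_leaf(current, update):
        if _is_shielded(current):
            # Skip the stepping formula and take the new value as-is.
            return update
        return current + dt * update

    # is_leaf stops traversal at shielded wrappers in both trees.
    return jax.tree_util.tree_map(step_leaf, state, updates, is_leaf=_is_shielded)
```

In this sketch the model is responsible for placing the next key, for example one produced by jax.random.split(), in the shielded slot of its updates.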