Time-Steppers
Time-stepping schemes and utilities for using them to update model states.
To step a model through time, use SteppedModel and combine it
with a time-stepper.
Available time-steppers included in this package are:
- AB3Stepper: Third-order Adams-Bashforth stepper (the same method used in PyQG)
- EulerStepper: Forward Euler first-order method
Combined Models
These utilities add time-stepping to base models.
- class pyqg_jax.steppers.SteppedModel(model, stepper)
Combine an inner model with a time stepper.
This class simplifies the process of stepping a base model through time by handling the interactions between the model and time stepper.
- Parameters:
model – The inner model to step through time.
stepper – The time stepper applying the updates to the model each step.
- model
The inner model being stepped.
- stepper
The time stepper used to apply the updates.
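To make the delegation that SteppedModel performs more concrete, here is a minimal pure-Python sketch of the same pattern on a toy scalar decay problem (dy/dt = -y). Everything in it is hypothetical. ToyStepperState, ToyDecayModel, ToyEulerStepper and ToySteppedModel are illustration-only stand-ins. The model hooks get_updates and postprocess_state are names assumed for this sketch. A seeded random.Random replaces jax.random.key. This is not the pyqg_jax implementation. It only shows how such a wrapper can route each call to either the model or the stepper.

```python
import dataclasses
import random


@dataclasses.dataclass(frozen=True)
class ToyStepperState:
    # Hypothetical stand-in for StepperState: inner state plus time counters.
    state: float
    t: float
    tc: int


class ToyDecayModel:
    # Hypothetical inner model for dy/dt = -rate * y.
    def __init__(self, rate=1.0):
        self.rate = rate

    def create_initial_state(self, key):
        # A seeded random.Random stands in for jax.random.key here.
        return random.Random(key).uniform(0.5, 1.5)

    def get_updates(self, state):
        return -self.rate * state

    def postprocess_state(self, state):
        return state  # a real model might filter here; the toy does nothing

    def get_full_state(self, state):
        # "Full" state: the prognostic value plus a derived diagnostic.
        return {"y": state, "dydt": self.get_updates(state)}


class ToyEulerStepper:
    # Hypothetical first-order stepper used only to drive the toy model.
    def __init__(self, dt):
        self.dt = dt

    def initialize_stepper_state(self, state):
        return ToyStepperState(state=state, t=0.0, tc=0)

    def apply_updates(self, stepper_state, updates):
        return ToyStepperState(
            state=stepper_state.state + self.dt * updates,
            t=stepper_state.t + self.dt,
            tc=stepper_state.tc + 1,
        )


class ToySteppedModel:
    # Mirrors the delegation described for SteppedModel's methods.
    def __init__(self, model, stepper):
        self.model = model
        self.stepper = stepper

    def create_initial_state(self, key, *args, **kwargs):
        inner = self.model.create_initial_state(key, *args, **kwargs)
        return self.initialize_stepper_state(inner)

    def initialize_stepper_state(self, state):
        return self.stepper.initialize_stepper_state(state)

    def step_model(self, stepper_state):
        updates = self.model.get_updates(stepper_state.state)
        stepped = self.stepper.apply_updates(stepper_state, updates)
        # Post-processing is applied here, after the stepper's raw update.
        return dataclasses.replace(
            stepped, state=self.model.postprocess_state(stepped.state)
        )

    def get_full_state(self, stepper_state):
        return self.model.get_full_state(stepper_state.state)


stepped_model = ToySteppedModel(ToyDecayModel(rate=1.0), ToyEulerStepper(dt=0.1))
s0 = stepped_model.create_initial_state(0)
s1 = stepped_model.step_model(s0)
full = stepped_model.get_full_state(s1)
```

Keeping post-processing in the wrapper rather than in the stepper matches the notes on the steppers' apply_updates methods, which state that they do not apply post-processing themselves.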
- create_initial_state(key, *args, **kwargs)
Create a new wrapped initial state with random initialization.
This function defers to model to initialize the inner state, passing it any additional arguments. It then wraps the result by calling initialize_stepper_state on stepper.
- Parameters:
key (jax.random.key) – The PRNG state used as the random key for initialization.
*args – Arbitrary additional arguments for the model’s initialization function.
**kwargs – Arbitrary additional arguments for the model’s initialization function.
- Returns:
The new wrapped state with random initialization.
- Return type:
StepperState
- initialize_stepper_state(state, /)
Wrap an existing state from model in a StepperState, preparing it for time stepping.
This function takes an existing inner model state and wraps it so that it can be stepped through time by stepper. It defers to stepper to perform the wrapping.
- Parameters:
state – The inner model state to wrap. The type depends on model but is likely to be PseudoSpectralState or ParameterizedModelState.
- Returns:
A wrapped copy of state.
- Return type:
StepperState
- step_model(stepper_state, /)
Update a state by computing the next time step.
This method handles the interaction between model and stepper, including post-processing/filtering.
To take multiple steps over time, combine this method with jax.lax.scan().
- Parameters:
stepper_state (StepperState) – The wrapped state to step forward in time.
- Returns:
The updated wrapped state, a new object.
- Return type:
StepperState
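The documentation suggests driving step_model with jax.lax.scan(). With pyqg_jax this would typically mean writing a step_fn(carry, x) that calls stepped_model.step_model(carry) and returns (next_state, output), then calling jax.lax.scan(step_fn, init_state, None, length=n_steps). So that the pattern can be run without JAX, the sketch below substitutes a pure-Python scan that follows the reference semantics given in the JAX documentation, and a hypothetical toy_step_model (forward Euler on dy/dt = -y) in place of a real SteppedModel. ToyState and toy_step_model are names invented for this illustration.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    # Hypothetical stand-in for a wrapped StepperState.
    y: float
    t: float
    tc: int


def toy_step_model(s, dt=0.1):
    # One forward Euler step of dy/dt = -y; builds a new state object.
    return ToyState(y=s.y + dt * (-s.y), t=s.t + dt, tc=s.tc + 1)


def scan(f, init, xs, length=None):
    # Pure-Python model of jax.lax.scan's semantics (outputs kept as a list).
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def step_fn(carry, _x):
    next_state = toy_step_model(carry)
    # Emit the model time as the per-step output.
    return next_state, next_state.t


init_state = ToyState(y=1.0, t=0.0, tc=0)
final_state, times = scan(step_fn, init_state, None, length=10)
```

Under the real jax.lax.scan, step_fn is traced once and compiled, so it should be a pure function of its inputs. Returning None as the per-step output is also allowed and avoids accumulating a history of full states.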
- get_full_state(stepper_state, /)
Expand a wrapped partial state into an unwrapped full state.
This function defers to model to compute the full state.
- Parameters:
stepper_state (StepperState) – The wrapped state to be expanded.
- Returns:
The expanded state. The real type depends on model, but is likely to be FullPseudoSpectralState.
- Return type:
- class pyqg_jax.steppers.StepperState(*, state, t, tc)
Model state wrapped for time-stepping
Warning
You should not construct this class yourself. Instead, obtain instances from your chosen time stepper or from SteppedModel.
- state
The inner state from the model being stepped forward. The actual type of state depends on the model being stepped.
- t
The current model time.
- Type:
jax.numpy.float32
- tc
The current model timestep, counted as the number of steps taken so far.
- Type:
jax.numpy.uint32
- update(**kwargs)
Replace values stored in this state.
This function produces a new state object, containing the replacement values.
The keyword arguments may be any of state, t, or tc.
The object this method is called on is not modified.
- Parameters:
state (PseudoSpectralState or ParameterizedModelState, optional) – Replacement value for state.
t (jax.numpy.float32, optional) – Replacement value for t, the current model time.
tc (jax.numpy.uint32, optional) – Replacement value for tc.
- Returns:
A copy of this object with the specified values replaced.
- Return type:
StepperState
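update() returns a modified copy and never mutates the original object. A minimal way to get the same semantics in plain Python is a frozen dataclass combined with dataclasses.replace. The ToyStepperState class below is hypothetical and only mirrors the three documented fields. The real StepperState may be built differently (for example as a JAX PyTree), so treat this as an illustration of the behavior, not of its implementation.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyStepperState:
    # Hypothetical immutable state with the three documented fields.
    state: object
    t: float
    tc: int

    def update(self, **kwargs):
        # dataclasses.replace builds a new instance; self is not modified.
        # Keyword names that are not fields raise TypeError.
        return dataclasses.replace(self, **kwargs)


old = ToyStepperState(state=[1.0, 2.0], t=0.0, tc=0)
new = old.update(t=3600.0, tc=1)
```

Fields that are not passed to update() are carried over by reference, so the new object shares the same inner state as the old one.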
Time Stepping Schemes
Implemented time-stepping schemes: AB3Stepper and EulerStepper.
- class pyqg_jax.steppers.AB3Stepper(dt)
Third-order Adams-Bashforth stepper.
This is the same time stepping scheme as used in PyQG.
This time-stepper bootstraps using lower-order Adams-Bashforth schemes for its first two steps: first order (forward Euler) on the first step and second order on the second.
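For reference, the standard Adams-Bashforth update formulas are written out below for a generic tendency f, with f_n = f(q_n) and step Δt. The symbols q, f and Δt are notation introduced here rather than names from pyqg_jax. The startup sequence assumes, as in PyQG, that the first step is forward Euler (AB1) and the second is AB2.

```latex
\begin{aligned}
q_1 &= q_0 + \Delta t\, f_0 && \text{(step 1: AB1, forward Euler)}\\
q_2 &= q_1 + \Delta t\left(\tfrac{3}{2} f_1 - \tfrac{1}{2} f_0\right) && \text{(step 2: AB2)}\\
q_{n+1} &= q_n + \Delta t\left(\tfrac{23}{12} f_n - \tfrac{16}{12} f_{n-1} + \tfrac{5}{12} f_{n-2}\right) && (n \ge 2\text{: AB3})
\end{aligned}
```

The full AB3 update needs the two previous tendencies, so a stepper implementing it presumably has to carry that history in its wrapped state.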
- Parameters:
dt (float) – Numerical time step
- initialize_stepper_state(state)
Wrap an existing state from a model in a StepperState to prepare it for time stepping.
This initializes a new StepperState starting from time 0.
- Parameters:
state – The model state to wrap.
- Returns:
The wrapped state. Note that this will be a subclass of StepperState appropriate for this time stepper.
- Return type:
StepperState
- apply_updates(stepper_state, updates)
Apply updates to the existing stepper_state, producing the next step in time.
updates should be provided by the model that produced StepperState.state.
- Parameters:
stepper_state (StepperState) – The time-stepper wrapped state to be updated.
updates (PseudoSpectralState or ParameterizedModelState) – The unwrapped updates to apply. The actual type of updates depends on the model being stepped.
- Returns:
The updated, wrapped state at the next time step.
- Return type:
StepperState
Note
This method does not apply post-processing to the updated state.
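The sketch below is a hypothetical scalar AB3 stepper applied to dy/dt = -y. ToyAB3Stepper and AB3State are names invented for this illustration, not pyqg_jax API. It assumes the PyQG-style startup (forward Euler on the first step, AB2 on the second) and keeps the two most recent updates in the wrapped state. Like the documented apply_updates, it does no post-processing.

```python
import math
from typing import NamedTuple, Tuple


class AB3State(NamedTuple):
    # Hypothetical wrapped state: value, time, step count, recent updates.
    y: float
    t: float
    tc: int
    prev: Tuple[float, ...]  # most recent update first; at most two kept


class ToyAB3Stepper:
    def __init__(self, dt):
        self.dt = dt

    def initialize_stepper_state(self, y):
        # Start at time 0 with an empty update history.
        return AB3State(y=y, t=0.0, tc=0, prev=())

    def apply_updates(self, s, f):
        if s.tc == 0:
            incr = f  # first step: AB1 (forward Euler)
        elif s.tc == 1:
            incr = 1.5 * f - 0.5 * s.prev[0]  # second step: AB2
        else:
            incr = (23 * f - 16 * s.prev[0] + 5 * s.prev[1]) / 12  # AB3
        return AB3State(
            y=s.y + self.dt * incr,
            t=s.t + self.dt,
            tc=s.tc + 1,
            prev=(f,) + s.prev[:1],
        )


def integrate(dt, t_end=1.0):
    # Integrate dy/dt = -y from y(0) = 1 and return the value at t_end.
    stepper = ToyAB3Stepper(dt)
    s = stepper.initialize_stepper_state(1.0)
    for _ in range(round(t_end / dt)):
        s = stepper.apply_updates(s, -s.y)  # the "model" supplies f = -y
    return s.y


exact = math.exp(-1.0)
err_coarse = abs(integrate(0.02) - exact)
err_fine = abs(integrate(0.01) - exact)
```

In this toy, the one-time error from the forward Euler first step scales like Δt², which dominates the Δt³ error of the later AB3 steps at these step sizes. Even so, the result is far more accurate than forward Euler used throughout.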
- class pyqg_jax.steppers.EulerStepper(dt)
Forward Euler (first-order) stepper.
Added in version 0.8.0.
- Parameters:
dt (float) – Numerical time step
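Forward Euler advances the state with q_{n+1} = q_n + Δt f(q_n) and needs no update history. The hypothetical sketch below shows that arithmetic together with the t and tc bookkeeping a wrapped state carries. EulerState and euler_apply_updates are illustrative names, not pyqg_jax API. With Δt = 0.25 every intermediate value is exact in binary floating point, so the result equals (1 - Δt)^4.

```python
from typing import NamedTuple


class EulerState(NamedTuple):
    # Hypothetical wrapped state: value, model time, and step count.
    y: float
    t: float
    tc: int


def euler_apply_updates(s, dydt, dt):
    # Forward Euler: y_{n+1} = y_n + dt * f(y_n); time advances by dt.
    return EulerState(y=s.y + dt * dydt, t=s.t + dt, tc=s.tc + 1)


dt = 0.25
s = EulerState(y=1.0, t=0.0, tc=0)
for _ in range(4):
    s = euler_apply_updates(s, -s.y, dt)  # dy/dt = -y
```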
- apply_updates(stepper_state, updates)
Apply updates to the existing stepper_state, producing the next step in time.
updates should be provided by the model that produced StepperState.state.
- Parameters:
stepper_state (StepperState) – The time-stepper wrapped state to be updated.
updates (PseudoSpectralState or ParameterizedModelState) – The unwrapped updates to apply. The actual type of updates depends on the model being stepped.
- Returns:
The updated, wrapped state at the next time step.
- Return type:
StepperState
Note
This method does not apply post-processing to the updated state.
- initialize_stepper_state(state)
Wrap an existing state from a model in a StepperState to prepare it for time stepping.
This initializes a new StepperState starting from time 0.
- Parameters:
state – The model state to wrap.
- Returns:
The wrapped state. Note that this will be a subclass of StepperState appropriate for this time stepper.
- Return type:
StepperState
State Manipulation
NoStepValue makes it possible to shield values from
time-stepping so they can be updated manually.
- class pyqg_jax.steppers.NoStepValue(value)
Shields contents from the provided time-steppers.
When a time-stepper encounters a value wrapped in this class, it will skip its normal stepping computations and directly use the value from the updates. This allows a user to manually update an auxiliary value outside the normal time-stepping.
For example, jax.random.key() values should not be time-stepped normally. Wrapping them in this class and updating them manually accomplishes this.
This class is used as part of ParameterizedModelState.
- Parameters:
value (object) – The inner value to wrap. This can be an arbitrary JAX PyTree.
- value
The internal, wrapped value
- Type:
pyqg_jax.steppers.P
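A minimal sketch of the shielding rule: when a stepper meets a wrapped value, it copies the value from the updates instead of performing the time-step arithmetic. ToyNoStepValue and euler_step_leaves are hypothetical names. The state here is a flat dict rather than a JAX PyTree, and an integer stands in for a jax.random.key. How pyqg_jax traverses PyTrees internally is not shown.

```python
from typing import Any


class ToyNoStepValue:
    # Hypothetical marker mirroring NoStepValue: wraps a value to shield it.
    def __init__(self, value: Any):
        self.value = value


def euler_step_leaves(state, updates, dt):
    # Forward Euler on each entry; shielded entries are copied from updates.
    new_state = {}
    for name, current in state.items():
        update = updates[name]
        if isinstance(current, ToyNoStepValue):
            # Skip the stepping arithmetic and use the manual value as-is.
            new_state[name] = ToyNoStepValue(update.value)
        else:
            new_state[name] = current + dt * update
    return new_state


state = {"q": 1.0, "rng": ToyNoStepValue(7)}
# The "model" returns a tendency for q and the replacement value for rng.
updates = {"q": -0.5, "rng": ToyNoStepValue(8)}
new_state = euler_step_leaves(state, updates, dt=0.1)
```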