Online Training

Other examples have illustrated some of the capabilities that are available with the models implemented in JAX. One additional JAX transformation that has not yet been used is jax.grad(). Because the base models are implemented in JAX, we can take gradients through multiple simulation time steps, training a parameterization online through a live simulation, as opposed to training on static snapshots.
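As a minimal illustration of this idea (a toy recurrence, not a pyqg_jax model), jax.grad() can differentiate a scalar loss through several steps of a jax.lax.scan:

import jax
import jax.numpy as jnp

def rollout_loss(x0):
    # Step a simple recurrence several times, then reduce to a scalar
    def step(x, _):
        x_next = 0.9 * x + jnp.sin(x)
        return x_next, None

    x_final, _ = jax.lax.scan(step, x0, None, length=5)
    return jnp.sum(x_final**2)

# Gradient of the loss with respect to the initial condition
grad_x0 = jax.grad(rollout_loss)(jnp.ones(3))

The same pattern, with the recurrence replaced by model time steps and the gradient taken with respect to network weights, is what we use for online training below.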

Some existing work has explored the impact of this training approach with QG models implemented in PyTorch. Here we provide a sketch of an online training setup using Equinox for neural networks and Optax for optimizers.

%env JAX_ENABLE_X64=True
import functools
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import optax
import matplotlib.pyplot as plt
import pyqg_jax
env: JAX_ENABLE_X64=True

To carry out our training we will make use of several elements demonstrated in other examples, in particular the high-resolution SteppedModel used to generate reference data, the spectral coarsening operator Operator1, and the neural network module NNParam together with the module_to_single helper.

In addition to the high-resolution stepped model (size 128), the network, and the coarsening operator, we also configure an Adam optimizer to train our network. For more information on combining these optimizers with Equinox, consult the Equinox documentation and the Optax documentation. Note that Optax provides a separate object, here optim_state, representing the state of the optimizer, which must be updated as part of training.

DT = 3600.0
LEARNING_RATE = 5e-4

big_model = pyqg_jax.steppers.SteppedModel(
    model=pyqg_jax.qg_model.QGModel(
        nx=128,
        ny=128,
        precision=pyqg_jax.state.Precision.SINGLE,
    ),
    stepper=pyqg_jax.steppers.AB3Stepper(dt=DT),
)

coarse_op = Operator1(big_model.model, 32)

# Ensure that all module weights are float32
net = module_to_single(NNParam(key=jax.random.key(123)))

optim = optax.adam(LEARNING_RATE)
optim_state = optim.init(eqx.filter(net, eqx.is_array))

With our network and optimizer initialized we generate several sample states to represent training data. These states are generated at the high resolution of size 128, and coarsened to the low resolution of size 32. These small states form our training targets, target_q. In a real application, these reference trajectories would likely be pre-computed and loaded from disk.

Note that we do not generate any explicit forcing targets here since we will be supervising on the states directly.

@functools.partial(jax.jit, static_argnames=["num_steps"])
def generate_train_data(seed, num_steps):

    def step(carry, _x):
        next_state = big_model.step_model(carry)
        small_state = coarse_op.coarsen_state(carry.state)
        return next_state, small_state.q

    _final_big_state, target_q = jax.lax.scan(
        step, big_model.create_initial_state(jax.random.key(seed)), None, length=num_steps
    )
    return target_q

target_q = generate_train_data(123, num_steps=100)
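The scan stacks the coarsened snapshots along a leading time dimension. As a quick check (assuming the default two-layer QGModel configuration), the shape should be (num_steps, nz, ny, nx):

# Expect (num_steps, nz, ny, nx); with the two-layer model and the
# size-32 coarsening this should be (100, 2, 32, 32)
target_q.shape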

Next we provide a function to roll out a trajectory starting from some initial state. In this case we provide the state as a bare JAX array and have to package it into a model state. Another example of this process is included in “Basic Time Stepping.” See “Implementing a Parameterization” for another example of using a neural network parameterization.

def roll_out_with_net(init_q, net, num_steps):

    @pyqg_jax.parameterizations.q_parameterization
    def net_parameterization(state, param_aux, model):
        assert param_aux is None
        q = state.q
        # Scale states to improve stability
        # This 1e-6 is for illustration only
        q_in = (q / 1e-6).astype(jnp.float32)
        q_param = net(q_in)
        return 1e-6 * q_param.astype(q.dtype), None

    # Extract the small model from the coarsener
    # Then wrap it in the network parameterization and stepper
    # Make sure to match time steps
    small_model = pyqg_jax.steppers.SteppedModel(
        model=pyqg_jax.parameterizations.ParameterizedModel(
            model=coarse_op.small_model,
            param_func=net_parameterization,
        ),
        stepper=pyqg_jax.steppers.AB3Stepper(dt=DT),
    )
    # Package our state
    # First, package it for the base model
    base_state = small_model.model.model.create_initial_state(
        jax.random.key(0)
    ).update(q=init_q)
    # Next, wrap it for the parameterization and stepper
    init_state = small_model.initialize_stepper_state(
        small_model.model.initialize_param_state(base_state)
    )

    def step(carry, _x):
        next_state = small_model.step_model(carry)
        # NOTE: Be careful! We output the *old* state for the trajectory
        # Otherwise the initial step would be skipped
        return next_state, carry.state.model_state.q

    # Roll out the state
    _final_step, traj = jax.lax.scan(
        step, init_state, None, length=num_steps
    )
    return traj

We provide a function using the above to roll out a trajectory at the low resolution and compute errors against the reference trajectory target_q. In this case we use a simple MSE loss for training. We also use Equinox’s “filtered” transforms (equinox.filter_jit(), equinox.filter_value_and_grad()) since these interact more naturally with the Equinox neural network modules.

Note

Online training with long rollouts may lead to out-of-memory errors. One solution is to use jax.checkpoint() inside the scan to save memory through recomputation.

An implementation of this is available in powerpax.checkpoint_chunked_scan(), or see this sample code for a starting point.
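For example, inside roll_out_with_net the scan body could be wrapped like this (a sketch only; the memory savings come at the cost of recomputation during the backward pass):

    def step(carry, _x):
        next_state = small_model.step_model(carry)
        return next_state, carry.state.model_state.q

    # Recompute intermediate values on the backward pass instead of storing them
    _final_step, traj = jax.lax.scan(
        jax.checkpoint(step), init_state, None, length=num_steps
    )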

def compute_traj_errors(target_q, net):
    rolled_out = roll_out_with_net(
        init_q=target_q[0],
        net=net,
        num_steps=target_q.shape[0],
    )
    err = rolled_out - target_q
    return err

@eqx.filter_jit
def train_batch(batch, net, optim_state):

    def loss_fn(net, batch):
        err = jax.vmap(functools.partial(compute_traj_errors, net=net))(batch)
        mse = jnp.mean(err**2)
        return mse

    # Compute loss value and gradients
    loss, grads = eqx.filter_value_and_grad(loss_fn)(net, batch)
    # Update the network weights
    updates, new_optim_state = optim.update(grads, optim_state, net)
    new_net = eqx.apply_updates(net, updates)
    # Return the loss, updated net, updated optimizer state
    return loss, new_net, new_optim_state

We now use these components to run a short training loop and report the loss after each step. The training steps are all JIT compiled.

For the training function above, batch has shape (batch_size, num_time_steps, nz, ny, nx). Each batch should include at least two time steps; otherwise the parameterization is never evaluated, since the first entry of each rolled-out trajectory is the unmodified initial state.

BATCH_SIZE = 8
BATCH_STEPS = 10
assert BATCH_STEPS >= 2

np_rng = np.random.default_rng(seed=456)
losses = []
for batch_i in range(30):
    # Rudimentary shuffling in lieu of real data loader
    batch = np.stack(
        [
            target_q[start:start+BATCH_STEPS]
            for start in np_rng.integers(
                0, target_q.shape[0] - BATCH_STEPS, size=BATCH_SIZE
            )
        ]
    )
    loss, net, optim_state = train_batch(batch, net, optim_state)
    losses.append(loss)
    print(f"Step {batch_i + 1:02}: loss={loss.item():.4E}")
Step 01: loss=7.2966E-07
Step 02: loss=6.6778E-07
Step 03: loss=6.0980E-07
Step 04: loss=5.5482E-07
Step 05: loss=5.0220E-07
Step 06: loss=4.5213E-07
Step 07: loss=4.0494E-07
Step 08: loss=3.6078E-07
Step 09: loss=3.1955E-07
Step 10: loss=2.8105E-07
Step 11: loss=2.4526E-07
Step 12: loss=2.1221E-07
Step 13: loss=1.8199E-07
Step 14: loss=1.5456E-07
Step 15: loss=1.2986E-07
Step 16: loss=1.0774E-07
Step 17: loss=8.8077E-08
Step 18: loss=7.0782E-08
Step 19: loss=5.5785E-08
Step 20: loss=4.2992E-08
Step 21: loss=3.2266E-08
Step 22: loss=2.3426E-08
Step 23: loss=1.6280E-08
Step 24: loss=1.0662E-08
Step 25: loss=6.4228E-09
Step 26: loss=3.4122E-09
Step 27: loss=1.4665E-09
Step 28: loss=4.0288E-10
Step 29: loss=3.9441E-11
Step 30: loss=2.1493E-10
plt.plot(np.arange(len(losses)) + 1, losses)
plt.xlabel("Step")
plt.ylabel("Step Loss")
plt.grid(True)
[Figure: per-step training loss decreasing over the 30 training steps]
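As a final check, we can roll out the trained parameterization over the full reference trajectory once more and report the remaining error. This is a brief sketch reusing the functions defined above:

# Roll out the trained network from the first reference state and
# measure the mean squared error against the full reference trajectory
eval_traj = roll_out_with_net(
    init_q=target_q[0],
    net=net,
    num_steps=target_q.shape[0],
)
eval_mse = jnp.mean((eval_traj - target_q) ** 2)
print(f"Full-trajectory MSE: {eval_mse.item():.4E}")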