Batches of Trajectories

Because the models and time steppers are fully implemented in JAX, we can apply JAX transformations to both the models and their trajectories. In this example we use jax.vmap() to run several trajectories at once.

If you are running on a GPU with small state dimensions, batching several trajectories can make better use of the GPU’s compute capacity.
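As a minimal illustration of what jax.vmap does, the sketch below uses a hypothetical toy update, toy_step (not part of pyqg_jax), written for a single state. vmap turns it into a function that acts on a stack of states, and the result matches looping over the states one at a time.

```python
import jax
import jax.numpy as jnp

def toy_step(x):
    # Hypothetical single-state update, standing in for model.step_model
    return 0.5 * x + jnp.sin(x)

# A stack of three states, each of shape (4,); the leading axis is the batch axis
xs = jnp.arange(12.0).reshape(3, 4)

# One vectorized call over the whole stack
batched = jax.vmap(toy_step)(xs)
# Reference result: apply the unbatched function to each state in turn
looped = jnp.stack([toy_step(x) for x in xs])

print(batched.shape)                           # (3, 4)
print(bool(jnp.allclose(batched, looped)))     # True
```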

%env JAX_ENABLE_X64=True
import functools
import matplotlib.pyplot as plt
import cmocean.cm as cmo
import jax
import jax.numpy as jnp
import pyqg_jax
env: JAX_ENABLE_X64=True

We begin by setting up a time-stepped model as in the basic time stepping example:

model = pyqg_jax.steppers.SteppedModel(
    model=pyqg_jax.qg_model.QGModel(
        nx=64,
        ny=64,
        precision=pyqg_jax.state.Precision.DOUBLE,
    ),
    stepper=pyqg_jax.steppers.AB3Stepper(dt=14400.0),
)

Next, we use vmap to create one initial state per key from a stack of PRNG keys:

# Split one initial RNG key into three, then stack and vmap
init_rngs = jnp.stack(jax.random.split(jax.random.key(0), 3))
init_states = jax.vmap(model.create_initial_state)(init_rngs)

init_states
AB3State(
  t=f32[3],
  tc=u32[3],
  state=PseudoSpectralState(qh=c128[3,2,64,33]),
)

Note the leading dimension of size 3, one for each initial configuration.
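The same pattern works for any function that takes a single key: split the key, then vmap over the resulting key array. jax.random.split already returns the keys stacked along a leading axis, so the jnp.stack call above is not strictly required. A minimal sketch, with a hypothetical toy_initial_state in place of model.create_initial_state:

```python
import jax
import jax.numpy as jnp

def toy_initial_state(rng):
    # Hypothetical stand-in for model.create_initial_state: a small random (2, 8) field
    return 1e-2 * jax.random.normal(rng, (2, 8))

# split() returns an array of 3 keys with shape (3,)
rngs = jax.random.split(jax.random.key(0), 3)
# vmap maps over the leading axis of the key array: one state per key
states = jax.vmap(toy_initial_state)(rngs)

print(rngs.shape)    # (3,)
print(states.shape)  # (3, 2, 8)
```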

We reuse the roll_out_state function from the basic stepping example, but wrap it in vmap before applying jit. Passing in_axes=(0, None) to vmap maps over the leading axis of state while leaving num_steps unbatched.

This time, however, we modify our scan loop to keep only the final state.

@functools.partial(jax.jit, static_argnames=["num_steps"])
@functools.partial(jax.vmap, in_axes=(0, None))
def roll_out_state(state, num_steps):

    def loop_fn(carry, _x):
        current_state = carry
        next_state = model.step_model(current_state)
        return next_state, None

    final_state, _ = jax.lax.scan(
        loop_fn, state, None, length=num_steps
    )
    return final_state
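To see the effect of returning None from loop_fn, here is a sketch with a hypothetical scalar toy_step in place of model.step_model. With None as the per-step output, scan returns only the final carry, so memory use does not grow with num_steps; returning the carry as the per-step output instead stacks the whole trajectory.

```python
import functools
import jax
import jax.numpy as jnp

def toy_step(x):
    # Hypothetical update with fixed point x = 2
    return 0.5 * x + 1.0

@functools.partial(jax.jit, static_argnames=["num_steps"])
@functools.partial(jax.vmap, in_axes=(0, None))
def roll_out_final(state, num_steps):
    def loop_fn(carry, _x):
        return toy_step(carry), None  # no per-step output is kept
    final_state, _ = jax.lax.scan(loop_fn, state, None, length=num_steps)
    return final_state

@functools.partial(jax.jit, static_argnames=["num_steps"])
@functools.partial(jax.vmap, in_axes=(0, None))
def roll_out_history(state, num_steps):
    def loop_fn(carry, _x):
        next_state = toy_step(carry)
        return next_state, next_state  # every step is stacked into the output
    _, history = jax.lax.scan(loop_fn, state, None, length=num_steps)
    return history

init = jnp.zeros(3)
final = roll_out_final(init, 10)
history = roll_out_history(init, 10)
print(final.shape)    # (3,)
print(history.shape)  # (3, 10)
```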

We can now roll out all three trajectories at the same time:

# Note that the vmap decorator prevents passing num_steps
# as a keyword argument
final_steps = roll_out_state(init_states, 7500)

final_steps
AB3State(
  t=f32[3],
  tc=u32[3],
  state=PseudoSpectralState(qh=c128[3,2,64,33]),
)

Note that we now have three final states, one for each trajectory in the batch.

Note

jax.vmap maps keyword arguments over their leading axis, so num_steps cannot be passed as a keyword/named argument here (see JAX#7465).
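One possible workaround, sketched here with a hypothetical toy_step and roll_out_single (not pyqg_jax functions), is to jit an outer function that takes num_steps as a static keyword argument and binds it with functools.partial before calling vmap, so vmap never sees it:

```python
import functools
import jax
import jax.numpy as jnp

def toy_step(x):
    # Hypothetical update with fixed point x = 2
    return 0.5 * x + 1.0

def roll_out_single(state, num_steps):
    # Unbatched roll-out of one state, keeping only the final carry
    def loop_fn(carry, _x):
        return toy_step(carry), None
    final_state, _ = jax.lax.scan(loop_fn, state, None, length=num_steps)
    return final_state

@functools.partial(jax.jit, static_argnames=["num_steps"])
def roll_out_batch(states, num_steps):
    # num_steps is bound by partial, so vmap only batches `states`
    return jax.vmap(functools.partial(roll_out_single, num_steps=num_steps))(states)

# num_steps can now be passed by name
final = roll_out_batch(jnp.zeros(3), num_steps=10)
```

The trade-off is an extra wrapper function, in exchange for a call site that can name num_steps explicitly.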

We can plot the first layer of each trajectory’s final state:

final_q = final_steps.state.q

vmax = jnp.abs(final_q[:, 0]).max()
fig, axs = plt.subplots(1, final_q.shape[0], layout="constrained")
for i, (single_q, ax) in enumerate(zip(final_q, axs)):
    ax.set_title(f"Trajectory {i}")
    ax.imshow(single_q[0], cmap=cmo.balance, vmin=-vmax, vmax=vmax)
[Figure: first-layer q of the final state for Trajectory 0, 1, and 2]

Notice that each trajectory has evolved independently and reached a distinct state.

Batching Models

Because the models, like the states, are JAX PyTrees, we can also vmap over a batch of models.
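The idea is that vmap maps over every array leaf of a PyTree, and a model object is just another PyTree. A minimal sketch, using a hypothetical ToyModel NamedTuple (JAX treats NamedTuples as PyTrees) in place of SteppedModel:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyModel(NamedTuple):
    # Hypothetical two-parameter model; each field is a PyTree leaf
    decay: jax.Array
    forcing: jax.Array

    def step_model(self, x):
        return self.decay * x + self.forcing

def make_toy_model(decay, forcing):
    return ToyModel(decay=decay, forcing=forcing)

# vmap over the constructor: every leaf gains a leading batch axis of size 2
toy_models = jax.vmap(make_toy_model)(jnp.array([0.5, 0.9]), jnp.array([1.0, 0.1]))
print(toy_models.decay.shape)  # (2,)
```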

reks = jnp.array([5.787e-7, 7e-08])
deltas = jnp.array([0.25, 0.1])
betas = jnp.array([1.5e-11, 1e-11])

def make_model(rek, delta, beta):
    model = pyqg_jax.steppers.SteppedModel(
        model=pyqg_jax.qg_model.QGModel(
            nx=64,
            ny=64,
            precision=pyqg_jax.state.Precision.DOUBLE,
            rek=rek,
            delta=delta,
            beta=beta,
        ),
        stepper=pyqg_jax.steppers.AB3Stepper(dt=14400.0),
    )
    return model

models = jax.vmap(make_model)(reks, deltas, betas)

models
SteppedModel(
  model=QGModel(
    nx=64,
    ny=64,
    L=f64[2],
    W=f64[2],
    rek=f64[2],
    filterfac=f64[2],
    f=None,
    g=f64[2],
    beta=f64[2],
    rd=f64[2],
    delta=f64[2],
    H1=i64[2],
    U1=f64[2],
    U2=f64[2],
    precision=<Precision.DOUBLE: 2>,
  ),
  stepper=AB3Stepper(dt=f64[2]),
)

Note

You can vary parameters between the models in a batch, except for parameters that affect the shape or dtype of the values. In particular, nx, ny, nz, and precision must be the same for every member of the batch.
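The restriction follows from how vmap works: inside vmap a batched value is a traced array, while array shapes must be concrete Python integers. A minimal sketch of the failure, with a hypothetical toy_field function whose output size depends on the batched value (much as nx and ny set the grid size):

```python
import jax
import jax.numpy as jnp

def toy_field(n):
    # Hypothetical: the array's shape depends on the batched argument
    return jnp.zeros(n)

try:
    jax.vmap(toy_field)(jnp.array([32, 64]))
    failed = False
except TypeError:
    # JAX rejects a traced value where a concrete shape is required
    failed = True

print(failed)  # True
```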

The batched model’s methods must be called inside vmap to work correctly. Here we create an initial state for each model from the same RNG key.

def make_initial_state(model, rng):
    return model.create_initial_state(rng)

# Call the function with a constant RNG key (seeded to zero) but different stacked models.
# It would also be possible to provide different RNG keys for each model as was done above.
batch_state = jax.vmap(functools.partial(make_initial_state, rng=jax.random.key(0)))(
    models
)

batch_state
AB3State(
  t=f32[2],
  tc=u32[2],
  state=PseudoSpectralState(qh=c128[2,2,64,33]),
)

The leading dimension of size 2 is the batch dimension. We can now roll these out, each with its own model.

Both initial states are identical:

vmax = jnp.abs(batch_state.state.q[:, 0]).max()
fig, axs = plt.subplots(1, batch_state.state.q.shape[0], layout="constrained")
for i, (single_q, ax) in enumerate(zip(batch_state.state.q, axs)):
    ax.set_title(f"Trajectory {i}")
    ax.imshow(single_q[0], cmap=cmo.balance, vmin=-vmax, vmax=vmax)
[Figure: first-layer q of the initial state for Trajectory 0 and 1]
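Both models received the same key through functools.partial. An equivalent approach marks the shared argument as unbatched with in_axes=None. In this sketch a hypothetical toy_initial_state takes a batch of hypothetical model parameters (a dict PyTree) and one key; vmap maps over the parameters and broadcasts the key, and because the output does not depend on the parameters, every member of the batch gets the same state:

```python
import jax
import jax.numpy as jnp

def toy_initial_state(model_params, rng):
    # Hypothetical: the initial state ignores the model and depends only on the key
    del model_params
    return jax.random.normal(rng, (2, 8))

model_params = {"beta": jnp.array([1.5e-11, 1e-11])}  # hypothetical batch of 2
shared_key = jax.random.key(0)

# in_axes=(0, None): map over the params, reuse the single key for every member
states = jax.vmap(toy_initial_state, in_axes=(0, None))(model_params, shared_key)
print(states.shape)  # (2, 2, 8)
```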

We now rework our roll_out_state function to take the model as an additional argument. Passing in_axes=(0, 0, None) to vmap maps over both the models and the states, while num_steps stays unbatched.

@functools.partial(jax.jit, static_argnames=["num_steps"])
@functools.partial(jax.vmap, in_axes=(0, 0, None))
def roll_out_batch_models(model, state, num_steps):

    def loop_fn(carry, _x):
        current_state = carry
        next_state = model.step_model(current_state)
        return next_state, None

    final_state, _ = jax.lax.scan(
        loop_fn, state, None, length=num_steps
    )
    return final_state

batch_model_final = roll_out_batch_models(
    models, batch_state, 7500
)

batch_model_final
AB3State(
  t=f32[2],
  tc=u32[2],
  state=PseudoSpectralState(qh=c128[2,2,64,33]),
)

Plotting the final states shows the impact of the different model parameters: the second model has produced a trajectory that has not yet finished warming up.

final_q = batch_model_final.state.q

vmax = jnp.abs(final_q[:, 0]).max()
fig, axs = plt.subplots(1, final_q.shape[0], layout="constrained")
for i, (single_q, ax) in enumerate(zip(final_q, axs)):
    ax.set_title(f"Trajectory {i}")
    ax.imshow(single_q[0], cmap=cmo.balance, vmin=-vmax, vmax=vmax)
[Figure: first-layer q of the final state for Trajectory 0 and 1]
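The in_axes=(0, 0, None) pattern used by roll_out_batch_models can be checked on a toy problem. In this sketch a hypothetical ToyModel NamedTuple stands in for SteppedModel. Each member has its own parameters, so the two trajectories approach different fixed points, forcing / (1 - decay); with decay 0.9 the second member is still far from its fixed point after 10 steps.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyModel(NamedTuple):
    # Hypothetical stand-in for SteppedModel with two parameters
    decay: jax.Array
    forcing: jax.Array

    def step_model(self, x):
        return self.decay * x + self.forcing

@functools.partial(jax.jit, static_argnames=["num_steps"])
@functools.partial(jax.vmap, in_axes=(0, 0, None))
def roll_out_toy_models(model, state, num_steps):
    # Inside vmap, `model` is a single (unbatched) ToyModel
    def loop_fn(carry, _x):
        return model.step_model(carry), None
    final_state, _ = jax.lax.scan(loop_fn, state, None, length=num_steps)
    return final_state

# A batch of two models, built directly with stacked leaves
toy_models = ToyModel(decay=jnp.array([0.5, 0.9]), forcing=jnp.array([1.0, 0.1]))
final = roll_out_toy_models(toy_models, jnp.zeros(2), 10)
```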