Implementing a Parameterization

In addition to the built-in parameterizations provided in this package, it is possible to use your own code to implement a custom parameterization, for example one implemented by a neural network. Additional details on custom parameterizations are available in the documentation section “Implementing Parameterizations”.

We illustrate the creation of a simple neural network parameterization. For the neural network itself we use Equinox, but other libraries such as Flax could be used as well.

%env JAX_ENABLE_X64=True
import functools
import jax
import jax.numpy as jnp
import equinox as eqx
import pyqg_jax
env: JAX_ENABLE_X64=True

In our example we will use a small convolutional network. The network here is randomly initialized, but in real use it would likely use trained weights loaded from a file. Also, the architecture here has a padding size configured to keep the state sizes constant. Because system states are periodic, this example uses circular padding (see the padding_mode parameter of equinox.nn.Conv). Periodic padding can also be added using the "wrap" mode of jax.numpy.pad().

def param_to_single(param):
    """Downcast one pytree leaf from double to single precision.

    Inexact (floating or complex) arrays with dtype float64 are cast to
    float32, and complex128 arrays to complex64. Every other leaf is returned
    unchanged. This includes non-array leaves, integer arrays, and arrays
    that are already single precision.
    """
    if not eqx.is_inexact_array(param):
        return param
    dtype = param.dtype
    if dtype == jnp.dtype(jnp.float64):
        return param.astype(jnp.float32)
    if dtype == jnp.dtype(jnp.complex128):
        return param.astype(jnp.complex64)
    return param

def module_to_single(module):
    """Return a copy of ``module`` with each leaf passed through ``param_to_single``.

    The pytree is flattened, each leaf is converted, and the original tree
    structure is rebuilt. The result is that double-precision array weights
    become single precision.
    """
    leaves, treedef = jax.tree_util.tree_flatten(module)
    converted = [param_to_single(leaf) for leaf in leaves]
    return jax.tree_util.tree_unflatten(treedef, converted)

class NNParam(eqx.Module):
    """Small convolutional network used as an example parameterization.

    Layer stack: a 3x3 convolution from 2 to 5 channels, a ReLU, then a 3x3
    convolution from 5 back to 2 channels. Both convolutions use "SAME"
    padding with circular padding mode. With the default stride of 1, this
    keeps the spatial size unchanged and treats the domain as periodic.
    """

    ops: eqx.nn.Sequential

    def __init__(self, key):
        """Randomly initialize both convolution layers from ``key``."""
        conv_in_key, conv_out_key = jax.random.split(key, 2)

        def periodic_conv(in_ch, out_ch, conv_key):
            # Circular padding wraps the edges, matching the periodic state.
            return eqx.nn.Conv2d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                padding="SAME",
                key=conv_key,
                padding_mode="CIRCULAR",
            )

        layers = [
            periodic_conv(2, 5, conv_in_key),
            eqx.nn.Lambda(jax.nn.relu),
            periodic_conv(5, 2, conv_out_key),
        ]
        self.ops = eqx.nn.Sequential(layers)

    def __call__(self, x, *, key=None):
        """Apply the layer stack to ``x``, forwarding ``key`` to the Sequential."""
        return self.ops(x, key=key)

# Build a randomly initialized network (seeded with key 123) and cast every
# double-precision weight to single precision, so all module weights are
# float32 even though JAX_ENABLE_X64 is on.
net = module_to_single(NNParam(key=jax.random.key(123)))

net
NNParam(
  ops=Sequential(
    layers=(
      Conv2d(
        num_spatial_dims=2,
        weight=f32[5,2,3,3],
        bias=f32[5,1,1],
        in_channels=2,
        out_channels=5,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding='SAME',
        dilation=(1, 1),
        groups=1,
        use_bias=True,
        padding_mode='CIRCULAR'
      ),
      Lambda(fn=<wrapped function relu>),
      Conv2d(
        num_spatial_dims=2,
        weight=f32[2,5,3,3],
        bias=f32[2,1,1],
        in_channels=5,
        out_channels=2,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding='SAME',
        dilation=(1, 1),
        groups=1,
        use_bias=True,
        padding_mode='CIRCULAR'
      )
    )
  )
)

Next, we write a function wrapping the network so that it is suitable for use with ParameterizedModel. We illustrate a parameterization updating dqdt and so use q_parameterization to decorate our function.

This parameterization largely just evaluates the network. However, because the network weights are float32 while the simulation state is float64 we add casting around the network as needed.

Finally, note that this parameterization is stateless and depends only on the current model state. The param_aux value is always None and we always return None as the updated state value.

@pyqg_jax.parameterizations.q_parameterization
def net_parameterization(state, param_aux, model):
    """Stateless parameterization that evaluates ``net`` on the current q.

    ``param_aux`` must be ``None`` because this parameterization carries no
    state. The network runs in float32 and its output is cast back to the
    dtype of ``state.q``. Returns ``(q_update, None)``, where ``None`` is the
    (unchanged) auxiliary state.
    """
    # Stateless: the default aux initializer supplies None every step.
    assert param_aux is None
    q_state = state.q
    # The network weights are float32, so match the input precision to them.
    q_update = net(q_state.astype(jnp.float32))
    return q_update.astype(q_state.dtype), None

Next we construct our base QGModel, wrap it in a ParameterizedModel, and attach an AB3 time stepper using SteppedModel. Because our parameterization is stateless, we can use the default value for init_param_aux_func, which initializes the auxiliary state to None.

# Time-stepped model: a double-precision 32x32 QGModel whose dq/dt is
# modified by net_parameterization, advanced with an AB3 stepper.
model = pyqg_jax.steppers.SteppedModel(
    model=pyqg_jax.parameterizations.ParameterizedModel(
        model=pyqg_jax.qg_model.QGModel(
            nx=32,
            ny=32,
            precision=pyqg_jax.state.Precision.DOUBLE,
        ),
        param_func=net_parameterization,
        # init_param_aux_func is left at its default; per the printed repr
        # below it is _init_none, so param_aux starts (and stays) None.
    ),
    # dt=3600.0 -- presumably seconds (one hour), as in pyqg; confirm.
    stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
)

model
SteppedModel(
  model=ParameterizedModel(
    model=QGModel(
      nx=32,
      ny=32,
      L=1000000.0,
      W=1000000.0,
      rek=5.787e-07,
      filterfac=23.6,
      f=None,
      g=9.81,
      beta=1.5e-11,
      rd=15000.0,
      delta=0.25,
      H1=500,
      U1=0.025,
      U2=0.0,
      precision=<Precision.DOUBLE: 2>,
    ),
    param_func=jax.tree_util.Partial(<function __main__.net_parameterization>),
    init_param_aux_func=jax.tree_util.Partial(<function pyqg_jax.parameterizations._parameterized_model._init_none>),
  ),
  stepper=AB3Stepper(dt=3600.0),
)

As in Basic Time Stepping we now write a JIT compiled function to roll out a trajectory from an initial state. Because the states and parameterized model are all JAX PyTrees they can be passed as arguments through the function.

@functools.partial(jax.jit, static_argnames=["num_steps"])
def roll_out_state(state, stepped_model, num_steps):
    """Roll out ``num_steps`` states, starting from ``state``.

    ``stepped_model`` must be a pytree exposing ``step_model(state)``. The
    returned pytree stacks the visited states along a new leading axis of
    length ``num_steps``. Entry 0 is the input ``state`` itself, and the
    state after the last step is discarded. ``num_steps`` is static, so each
    distinct value triggers a recompilation.
    """

    def advance(prev_state, _unused):
        # Carry the stepped state forward, but record the state *before*
        # stepping so that the trajectory begins with the initial state.
        return stepped_model.step_model(prev_state), prev_state

    _, trajectory = jax.lax.scan(advance, state, xs=None, length=num_steps)
    return trajectory

We initialize our model state. Note the added NoStepValue wrapping the None state. This interacts with the time stepper so that the auxiliary state values are not time-stepped. It is up to your parameterization to provide new values for these, as needed.

# Random initial state (seeded with key 0). As the output shows, the
# parameterization's aux value (None here) is wrapped in NoStepValue.
init_state = model.create_initial_state(jax.random.key(0))

init_state
AB3State(
  t=f32[],
  tc=u32[],
  state=ParameterizedModelState(
    model_state=PseudoSpectralState(qh=c128[2,32,17]),
    param_aux=NoStepValue(value=None),
  ),
)

Finally we roll out the resulting trajectory.

# Roll out 10 states; each leaf gains a leading time axis of length 10, and
# traj[0] is init_state itself.
traj = roll_out_state(init_state, model, num_steps=10)

traj
AB3State(
  t=f32[10],
  tc=u32[10],
  state=ParameterizedModelState(
    model_state=PseudoSpectralState(qh=c128[10,2,32,17]),
    param_aux=NoStepValue(value=None),
  ),
)

Notice how—as in past examples—we have a new leading time dimension of size 10.

Stateful Parameterizations

The parameterization above was stateless, which is to say that its value depended only on the current model state and not on any history of previous steps. Not all parameterizations have this structure. Because JAX code must pass such state around explicitly, this package provides facilities for integrating stateful parameterizations into the time-stepped models.

Here we illustrate an extension of the neural network parameterization above that manages an additional auxiliary state value. These values can be arbitrary JAX PyTrees, which provides flexibility for a variety of use cases. In particular, here we use a nested tuple of several values.

Our auxiliary state will have two components:

  1. a key to provide randomness for use with the parameterization

  2. a two-step shift register of past parameterization outputs

These illustrate two different values that one might need. The first illustrates how to manage random states for use in stochastic parameterizations, and the second illustrates one possible approach to implementing parameterizations that depend on a history of previous steps.

The first step required is to provide code to initialize our param_aux values. Here, our function takes an additional argument, seed, which we use to construct a key. We also produce two placeholders for the past parameterization outputs: in this case, float32 arrays with the proper shapes filled with zeros.

def net_init_aux(model_state, model, seed):
    """Create the initial ``param_aux`` for ``net_key_parameterization``.

    Returns ``(key, (zeros, zeros))``:

    - ``key`` is a PRNG key built from the integer ``seed``.
    - The two zero arrays are float32, shaped like ``model_state.q``, and
      stand in for the two past parameterization outputs.

    ``model`` is accepted for interface compatibility but not used.
    """
    empty_output = jnp.zeros_like(model_state.q, dtype=jnp.float32)
    return jax.random.key(seed), (empty_output, empty_output)

Next we extend net_parameterization, defined above. This new version makes use of the param_aux argument, unpacking it to retrieve the RNG key and the two previous parameterization outputs. As before, we cast the current state to float32 before handing it to the network. We also split the RNG key, passing one key to the network and keeping the other for the next step. The parameterization output is the mean of the current network output and the two stored past outputs. Finally, we shift the stored outputs, dropping the oldest and appending the newest, to produce new_states.

This function returns the parameterization output as its first return value, and the new param_aux value as its second.

@pyqg_jax.parameterizations.q_parameterization
def net_key_parameterization(state, param_aux, model):
    """Stateful network parameterization that averages over recent outputs.

    ``param_aux`` is ``(key, (older_output, previous_output))``, as built by
    ``net_init_aux``. The network is evaluated on the current q (cast to
    float32) with a fresh subkey. The returned update is the mean of the
    current output and the two stored outputs, cast back to q's dtype.

    The new aux value is ``(next_key, (previous_output, current_output))``:
    the oldest output is dropped and the newest appended.
    """
    key_in, history = param_aux
    older_output, previous_output = history
    net_key, carry_key = jax.random.split(key_in)
    q = state.q
    current_output = net(q.astype(jnp.float32), key=net_key)
    # Mean over the current output and the two remembered outputs.
    averaged = jnp.mean(
        jnp.stack([current_output, older_output, previous_output]), axis=0
    )
    # Two-slot shift register: drop the oldest, append the newest.
    shifted_history = (previous_output, current_output)
    return averaged.astype(q.dtype), (carry_key, shifted_history)

Now we can use both of these functions as arguments to ParameterizedModel.

# Same double-precision 32x32 QGModel, now with the stateful parameterization
# and its custom aux initializer (which consumes the extra ``seed`` argument).
state_model = pyqg_jax.steppers.SteppedModel(
    model=pyqg_jax.parameterizations.ParameterizedModel(
        model=pyqg_jax.qg_model.QGModel(
            nx=32,
            ny=32,
            precision=pyqg_jax.state.Precision.DOUBLE,
        ),
        param_func=net_key_parameterization,
        init_param_aux_func=net_init_aux,
    ),
    # NOTE(review): dt=14400.0 differs from the 3600.0 used for the stateless
    # model above -- confirm this is intentional.
    stepper=pyqg_jax.steppers.AB3Stepper(dt=14400.0),
)

We can create a new model state with the parameterization values. Note that we provide the additional seed argument which is passed through to net_init_aux.

# Extra keyword arguments (here seed=10) are forwarded to net_init_aux.
init_key_state = state_model.create_initial_state(jax.random.key(0), seed=10)

init_key_state
AB3State(
  t=f32[],
  tc=u32[],
  state=ParameterizedModelState(
    model_state=PseudoSpectralState(qh=c128[2,32,17]),
    param_aux=NoStepValue(value=(key<fry>[], (f32[2,32,32], f32[2,32,32]))),
  ),
)

Notice that the param_aux value wrapped inside the NoStepValue object is no longer None. The key array is visible, along with two float32 arrays for past parameterization outputs.

Finally, as before we can roll out a trajectory:

# Reuse the same jitted roll-out; the aux key and output history also gain a
# leading time axis of length 10.
state_traj = roll_out_state(init_key_state, state_model, num_steps=10)

state_traj
AB3State(
  t=f32[10],
  tc=u32[10],
  state=ParameterizedModelState(
    model_state=PseudoSpectralState(qh=c128[10,2,32,17]),
    param_aux=NoStepValue(value=(key<fry>[10], (f32[10,2,32,32], f32[10,2,32,32]))),
  ),
)

Note that the parameterization states also have time dimensions. However, their values are not time-stepped by the AB3Stepper. Instead, the updated values produced by net_key_parameterization are used directly, and the time stepper leaves them untouched.