Implementing a Parameterization
In addition to the built-in parameterizations provided in this package, it is possible to use your own code to implement a custom parameterization, for example one implemented by a neural network. Additional details on custom parameterizations are available in the documentation section “Implementing Parameterizations”.
We illustrate the creation of a simple neural network parameterization. For the neural network itself we use Equinox, but other libraries such as Flax could be used as well.
%env JAX_ENABLE_X64=True
import functools
import jax
import jax.numpy as jnp
import equinox as eqx
import pyqg_jax
env: JAX_ENABLE_X64=True
In our example we will use a small convolutional network. The network
here is randomly initialized, but in real use it would likely use
trained weights loaded from a file. Also, the architecture here has a
padding size configured to keep the state sizes constant. Because
system states are periodic, this example uses circular padding (see
the padding_mode
parameter of equinox.nn.Conv
). Periodic
padding can also be added using the "wrap"
mode of
jax.numpy.pad()
.
def param_to_single(param):
    """Downcast one double-precision array leaf to single precision.

    float64 arrays become float32 and complex128 arrays become complex64.
    Anything else (leaves that are not inexact arrays, or arrays of any
    other dtype) is returned unchanged.
    """
    # Leaves that are not floating/complex arrays pass straight through.
    if not eqx.is_inexact_array(param):
        return param
    if param.dtype == jnp.dtype(jnp.float64):
        return param.astype(jnp.float32)
    if param.dtype == jnp.dtype(jnp.complex128):
        return param.astype(jnp.complex64)
    # Inexact array that is already single (or other) precision.
    return param
def module_to_single(module):
    """Return a copy of ``module`` with every leaf passed through ``param_to_single``.

    The module is treated as a JAX PyTree, so the conversion is applied
    leaf by leaf and the tree structure is preserved.
    """
    single_precision_module = jax.tree_util.tree_map(param_to_single, module)
    return single_precision_module
class NNParam(eqx.Module):
    """Small convolutional network: conv (2 -> 5), ReLU, conv (5 -> 2).

    Both convolutions use 3x3 kernels with "SAME" padding in circular
    mode, so the spatial size of the input is kept and the padding wraps
    around the periodic domain.
    """

    ops: eqx.nn.Sequential

    def __init__(self, key):
        """Randomly initialize both convolution layers from ``key``."""
        # One independent key per convolution layer.
        first_key, second_key = jax.random.split(key, 2)
        # Settings common to both convolutions.
        conv_settings = dict(
            kernel_size=3,
            padding="SAME",
            padding_mode="CIRCULAR",
        )
        layers = [
            eqx.nn.Conv2d(
                in_channels=2, out_channels=5, key=first_key, **conv_settings
            ),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Conv2d(
                in_channels=5, out_channels=2, key=second_key, **conv_settings
            ),
        ]
        self.ops = eqx.nn.Sequential(layers)

    def __call__(self, x, *, key=None):
        """Apply the layer stack to ``x``, forwarding the optional ``key``."""
        return self.ops(x, key=key)
# Build the randomly initialized network, then ensure that all of its
# weights are float32
net = NNParam(key=jax.random.key(123))
net = module_to_single(net)
net
NNParam(
ops=Sequential(
layers=(
Conv2d(
num_spatial_dims=2,
weight=f32[5,2,3,3],
bias=f32[5,1,1],
in_channels=2,
out_channels=5,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
Lambda(fn=<wrapped function relu>),
Conv2d(
num_spatial_dims=2,
weight=f32[2,5,3,3],
bias=f32[2,1,1],
in_channels=5,
out_channels=2,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
)
)
)
)
Next, we write a function wrapping the network so that it is suitable
for use with ParameterizedModel
. We illustrate a
parameterization updating dqdt
and so use q_parameterization
to decorate our
function.
This parameterization largely just evaluates the network. However,
because the network weights are float32
while the simulation state
is float64
we add casting around the network as needed.
Finally, note that this parameterization is stateless and depends
only on the current model state. The param_aux
value is always
None
and we always return None
as the updated state value.
@pyqg_jax.parameterizations.q_parameterization
def net_parameterization(state, param_aux, model):
    """Stateless parameterization that evaluates ``net`` on ``state.q``.

    The input is cast to float32 for the network and the output is cast
    back to the dtype of ``state.q``. Returns a pair of the network
    output and ``None``, since there is no auxiliary state to update.
    """
    # This parameterization carries no auxiliary state.
    assert param_aux is None
    full_precision_q = state.q
    net_output = net(full_precision_q.astype(jnp.float32))
    restored = net_output.astype(full_precision_q.dtype)
    return restored, None
Next we construct our base QGModel
wrapped in a ParameterizedModel
. Because our
parameterization is stateless we can use the default value for
init_param_aux_func
which initializes the state to None
.
# Assemble from the inside out: base QG model -> parameterized -> time-stepped
model = pyqg_jax.qg_model.QGModel(
    nx=32,
    ny=32,
    precision=pyqg_jax.state.Precision.DOUBLE,
)
model = pyqg_jax.parameterizations.ParameterizedModel(
    model=model,
    param_func=net_parameterization,
)
model = pyqg_jax.steppers.SteppedModel(
    model=model,
    stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
)
model
SteppedModel(
model=ParameterizedModel(
model=QGModel(
nx=32,
ny=32,
L=1000000.0,
W=1000000.0,
rek=5.787e-07,
filterfac=23.6,
f=None,
g=9.81,
beta=1.5e-11,
rd=15000.0,
delta=0.25,
H1=500,
U1=0.025,
U2=0.0,
precision=<Precision.DOUBLE: 2>,
),
param_func=jax.tree_util.Partial(<function __main__.net_parameterization>),
init_param_aux_func=jax.tree_util.Partial(<function pyqg_jax.parameterizations._parameterized_model._init_none>),
),
stepper=AB3Stepper(dt=3600.0),
)
As in Basic Time Stepping we now write a JIT compiled function to roll out a trajectory from an initial state. Because the states and parameterized model are all JAX PyTrees they can be passed as arguments through the function.
@functools.partial(jax.jit, static_argnames=("num_steps",))
def roll_out_state(state, stepped_model, num_steps):
    """Roll out ``num_steps`` states starting from ``state``.

    ``stepped_model.step_model`` is applied repeatedly inside
    ``jax.lax.scan``. The returned trajectory stacks the state *before*
    each step along a new leading axis, so its first entry is ``state``
    itself and the final stepped state is not included.
    """

    def advance(current, _unused):
        # Carry the stepped state forward; emit the pre-step state.
        return stepped_model.step_model(current), current

    _last, trajectory = jax.lax.scan(advance, state, xs=None, length=num_steps)
    return trajectory
We initialize our model state. Note the added NoStepValue
wrapping the None
state. This
interacts with the time stepper
so that the auxiliary state values are
not time-stepped. It is up to your parameterization to provide new
values for these, as needed.
# Create the initial state for the stepped model from a PRNG key.
init_state = model.create_initial_state(jax.random.key(0))
# Bare expression: displayed as the notebook cell output.
init_state
AB3State(
t=f32[],
tc=u32[],
state=ParameterizedModelState(
model_state=PseudoSpectralState(qh=c128[2,32,17]),
param_aux=NoStepValue(value=None),
),
)
Finally we roll out the resulting trajectory.
# Roll out 10 states starting from init_state using the stateless model.
traj = roll_out_state(init_state, model, num_steps=10)
# Bare expression: displayed as the notebook cell output.
traj
AB3State(
t=f32[10],
tc=u32[10],
state=ParameterizedModelState(
model_state=PseudoSpectralState(qh=c128[10,2,32,17]),
param_aux=NoStepValue(value=None),
),
)
Notice how—as in past examples—we have a new leading time dimension of size 10.
Stateful Parameterizations
The parameterization above was stateless, which is to say that its value depended only on the model state and not on any history of previous steps. Not all parameterizations have this structure, and this package has facilities to integrate stateful parameterizations into the time-stepped models, which is particularly useful in JAX where such state must be passed explicitly.
Here we illustrate an extension of the neural network parameterization above, managing an additional auxiliary state value. These can be arbitrary JAX PyTrees, which provides flexibility for a variety of use cases. In particular, here we use a nested tuple of several values.
Our auxiliary state will have two components:
- a key to provide randomness for use with the parameterization
- a two-step shift register of past parameterization outputs
These illustrate two different values that one might need. The first illustrates how to manage random states for use in stochastic parameterizations, and the second illustrates one possible approach to implementing parameterizations that depend on a history of previous states.
The first step required is to provide code to initialize our
param_aux
values. Here, our function takes one additional argument
seed
which we use to construct a key
. We also produce two
placeholders for the past parameterization outputs, in this case arrays with the
proper shapes filled with zeros.
def net_init_aux(model_state, model, seed):
    """Build the initial auxiliary state for the stateful parameterization.

    Returns ``(key, (zeros, zeros))``: a PRNG key created from ``seed``
    and two float32 zero arrays shaped like ``model_state.q``, which act
    as placeholders for the two stored past values. ``model`` is not used.
    """
    placeholder = jnp.zeros_like(model_state.q, dtype=jnp.float32)
    # Both history slots start out as the same all-zero array.
    return jax.random.key(seed), (placeholder, placeholder)
Next we extend net_parameterization
, defined above. This new version
makes use of the param_aux
argument, unpacking it to retrieve
previous states. As before we cast the current state to float32
before handing it to the network and we split
the RNG to provide a separate state to pass to our
network. The value we return is the mean of the new network output and
the two stored past outputs. Finally, we shift the stored outputs, dropping the
oldest and producing new_states
.
This function returns the parameterization output as its first return
value, and the new param_aux
value as its second.
@pyqg_jax.parameterizations.q_parameterization
def net_key_parameterization(state, param_aux, model):
    """Stateful parameterization averaging ``net`` output with two stored values.

    ``param_aux`` is ``(key, (older, newer))``. The key is split, with
    one half passed to ``net`` and the other half kept for the next call.
    The result is the mean of the new network output and the two stored
    values, cast back to the dtype of ``state.q``. The second return
    value is the updated auxiliary state ``(next_key, (newer, current))``.
    """
    prev_key, history = param_aux
    older, newer = history
    net_key, next_key = jax.random.split(prev_key)
    out_dtype = state.q.dtype
    current = net(state.q.astype(jnp.float32), key=net_key)
    averaged = jnp.stack([current, older, newer]).mean(axis=0)
    # Shift register: drop the oldest entry, append the newest output.
    shifted_history = (newer, current)
    return averaged.astype(out_dtype), (next_key, shifted_history)
Now we can use both of these functions as arguments to
ParameterizedModel
.
# Assemble from the inside out: base QG model -> parameterized (with the
# auxiliary-state initializer) -> time-stepped
state_model = pyqg_jax.qg_model.QGModel(
    nx=32,
    ny=32,
    precision=pyqg_jax.state.Precision.DOUBLE,
)
state_model = pyqg_jax.parameterizations.ParameterizedModel(
    model=state_model,
    param_func=net_key_parameterization,
    init_param_aux_func=net_init_aux,
)
state_model = pyqg_jax.steppers.SteppedModel(
    model=state_model,
    stepper=pyqg_jax.steppers.AB3Stepper(dt=14400.0),
)
We can create a new model state with the parameterization values. Note
that we provide the additional seed
argument which is passed through
to net_init_aux
.
# Create the initial state; the extra `seed` keyword is the additional
# argument accepted by net_init_aux.
init_key_state = state_model.create_initial_state(jax.random.key(0), seed=10)
# Bare expression: displayed as the notebook cell output.
init_key_state
AB3State(
t=f32[],
tc=u32[],
state=ParameterizedModelState(
model_state=PseudoSpectralState(qh=c128[2,32,17]),
param_aux=NoStepValue(value=(key<fry>[], (f32[2,32,32], f32[2,32,32]))),
),
)
Notice that the param_aux
value wrapped inside the NoStepValue
object is no longer None
. The key
array is visible,
along with two float32
arrays for past parameterization outputs.
Finally, as before we can roll out a trajectory:
# Roll out 10 states starting from init_key_state using the stateful model.
state_traj = roll_out_state(init_key_state, state_model, num_steps=10)
# Bare expression: displayed as the notebook cell output.
state_traj
AB3State(
t=f32[10],
tc=u32[10],
state=ParameterizedModelState(
model_state=PseudoSpectralState(qh=c128[10,2,32,17]),
param_aux=NoStepValue(value=(key<fry>[10], (f32[10,2,32,32], f32[10,2,32,32]))),
),
)
Note that the parameterization states also have time dimensions.
However, their values are not time-stepped by the AB3Stepper
, but
instead the updated values produced by net_key_parameterization
are
used directly and are left untouched by the time-stepper.