Managing Parameters and State#
We will show you how to…
manage variables from initialization to update.
split and re-assemble parameters and state.
use vmap with state that depends on the batch dimension.
class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias
This example model is a minimal example that contains both parameters (declared with self.param) and state variables (declared with self.variable).
The tricky part here is that we need to keep the state variables separate from the parameters that are to be optimized; the sketch below shows the two collections we will be splitting.
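A minimal inspection sketch (the input name and shape here are assumptions, not part of the guide) of the two collections that model.init returns for this module:

import jax.numpy as jnp
from jax import random

dummy_input = jnp.ones((4, 3))  # hypothetical batch of 4 examples with 3 features
variables = BiasAdderWithRunningMean().init(random.key(0), dummy_input)
# The result holds two collections:
#   variables['params']       -> {'bias': ...}   # optimized by the optimizer
#   variables['batch_stats']  -> {'mean': ...}   # updated via `mutable=...` at apply time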
First, we define update_step as follows (with a dummy loss that you should replace with your own):
def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()  # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
    loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state
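In practice this update would typically be wrapped in jax.jit. One possible sketch (not part of the original code; it assumes model and tx as defined in the training code below, closing over them so that only array arguments are traced):

import jax

@jax.jit
def train_step(x, opt_state, params, state):
  # `model` and `tx` are closed over; they are defined in the training code below.
  return update_step(model.apply, x, opt_state, params, state)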
Then we can write the actual training code:
model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
for _ in range(num_epochs):
  opt_state, params, state = update_step(
    model.apply, dummy_input, opt_state, params, state)
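Once training is done, the two collections can be re-assembled into a single variable dict for further apply calls. A minimal sketch using the names above (note that this toy module writes to its running mean on every call, so 'batch_stats' still has to be marked mutable):

# Re-assemble the optimized params with the final state.
variables = {'params': params, **state}
y, state_updates = model.apply(variables, dummy_input,
                               mutable=['batch_stats'])  # keep or discard the update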
vmap across the batch dimension#
When using vmap and managing state that depends on the batch dimension, for example when using BatchNorm, the setup above must be modified slightly. This is because any layer whose state depends on the batch dimension is not strictly vectorizable. In the case of BatchNorm, lax.pmean() must be used to average the statistics over the batch dimension so that the state stays in sync for each item in the batch (a standalone sketch of this mechanism follows below).
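As a standalone illustration (independent of the model below), lax.pmean() inside a function that is vmapped with a named axis replaces each per-example value with the mean over the whole batch:

import jax
import jax.numpy as jnp
from jax import lax

def per_example_stat(x):
  # Inside vmap, `x` is a single example; pmean averages over the named batch axis.
  return lax.pmean(x, axis_name='batch')

batch = jnp.arange(4.0)
print(jax.vmap(per_example_stat, axis_name='batch')(batch))  # [1.5 1.5 1.5 1.5]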
This requires two small changes. Firstly, we need to name the batch axis in our model definition. Here, this is done by specifying the axis_name argument of BatchNorm. In your own code, you may have to specify the axis_name argument of lax.pmean() directly.
class MLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    norm = partial(
      nn.BatchNorm,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      axis_name="batch",  # Name batch dim
    )

    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    y = nn.Dense(self.out_size)(x)

    return y
Secondly, we need to specify the same name when calling vmap in our training code:
def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):
  def batch_loss(params):
    def loss_fn(x, y):
      pred, updated_state = apply_fn(
        {'params': params, **state},
        x, mutable=list(state.keys())
      )
      return (pred - y) ** 2, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, out_axes=(0, None),  # Do not vmap `updated_state`.
      axis_name='batch'  # Name batch dim
    )(x_batch, y_batch)  # vmap only `x`, `y`, but not `state`.
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss
Note that we also need to specify that the model state does not have a batch dimension (this is what out_axes=(0, None) above does). Now we can train the model:
model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
for _ in range(num_epochs):
  opt_state, params, state, loss = update_step(
    model.apply, X, Y, opt_state, params, state)
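The snippets above leave dummy_input, X, Y, and num_epochs undefined. A hypothetical setup that makes the MLP training loop runnable (all names, shapes, and values here are assumptions, not part of the guide):

import jax.numpy as jnp
from jax import random

num_epochs = 10                              # assumption
dummy_input = jnp.ones((16,))                # a single unbatched example; vmap adds the batch dim
X = random.normal(random.key(1), (32, 16))   # hypothetical batch of 32 examples
Y = random.normal(random.key(2), (32, 1))    # matching regression targets (out_size=1)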