Managing Parameters and State

We will show you how to…

  • manage variables from initialization to updates.

  • split and re-assemble parameters and state.

  • use vmap with state that depends on the batch dimension.
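The snippets in this guide are excerpts and assume the usual Flax/JAX setup is already in scope; a typical set of imports (a sketch, not part of the original snippets) is:

from functools import partial

import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
from jax import random
import optax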

class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias

This example model is a minimal example that contains both parameters (declared with self.param) and state variables (declared with self.variable).
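To see that the two kinds of variables end up in separate collections, you can inspect what init returns. A minimal sketch, assuming a hypothetical input of shape (8, 4):

import jax
import jax.numpy as jnp

model = BiasAdderWithRunningMean()
x = jnp.ones((8, 4))  # hypothetical batch: 8 feature vectors of size 4
variables = model.init(jax.random.key(0), x)
print(jax.tree_util.tree_map(jnp.shape, variables))
# e.g. {'batch_stats': {'mean': (4,)}, 'params': {'bias': (4,)}}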

The tricky part here is that we need to split the state variables from the parameters we are going to optimize.

First, we define update_step as follows (with a dummy loss that you should replace with your own):

def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum() # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state

Then we can write the actual training code:

model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state)
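The loop above relies on dummy_input and num_epochs, which the guide leaves undefined. One possible setup (the shape and epoch count are assumptions):

import jax.numpy as jnp

num_epochs = 10                 # assumed number of training iterations
dummy_input = jnp.ones((8, 4))  # assumed batch: 8 feature vectors of size 4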

vmap across the batch dimension

When using vmap and managing state that depends on the batch dimension, for example when using BatchNorm, the setup above must be modified slightly. This is because any layer whose state depends on the batch dimension is not strictly vectorizable. In the case of BatchNorm, lax.pmean() must be used to average the statistics over the batch dimension so that the state stays in sync for each item in the batch.
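To make the role of the named axis concrete, here is a small standalone sketch (not taken from the guide) of lax.pmean() under jax.vmap: inside the mapped function each call sees only a single example, but the collective averages the value across the whole batch.

import jax
import jax.numpy as jnp
from jax import lax

def per_example_stat(x):
  # x is one example here; pmean averages it over the named batch axis.
  return lax.pmean(x, axis_name='batch')

batch = jnp.arange(6.0).reshape(3, 2)
print(jax.vmap(per_example_stat, axis_name='batch')(batch))
# Every row is the per-feature mean over the batch: [[2. 3.] [2. 3.] [2. 3.]]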

This requires two small changes. Firstly, we need to name the batch axis in the model definition. Here, this is done by specifying the axis_name argument of BatchNorm. In your own code this might require specifying the axis_name argument of lax.pmean() directly.

class MLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        axis_name="batch", # Name batch dim
    )

    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    y = nn.Dense(self.out_size)(x)

    return y

Secondly, we need to specify the same name when calling vmap in the training code:

def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):

  def batch_loss(params):
    def loss_fn(x, y):
      pred, updated_state = apply_fn(
        {'params': params, **state},
        x, train=True,  # Run BatchNorm in training mode so batch stats are computed.
        mutable=list(state.keys())
      )
      return (pred - y) ** 2, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, out_axes=(0, None),  # Do not vmap `updated_state`.
      axis_name='batch'  # Name batch dim
    )(x_batch, y_batch)  # vmap only `x`, `y`, but not `state`.
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)

  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss
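The out_axes=(0, None) argument above tells vmap that the second output (the updated state) carries no batch axis; this is only valid because the collective has already made that output identical for every batch element. A small standalone sketch of the same pattern (the function and values are illustrative, not from the guide):

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  per_example = x * 2.0                     # one result per batch element
  shared = lax.pmean(x, axis_name='batch')  # identical across the batch
  return per_example, shared

doubled, mean = jax.vmap(f, out_axes=(0, None), axis_name='batch')(jnp.arange(3.0))
print(doubled.shape, mean)  # (3,) 1.0 -- the second output has no batch axis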

Note that we also need to specify that the model state does not have a batch dimension. Now we are able to train the model:

model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state, loss = update_step(
      model.apply, X, Y, opt_state, params, state)
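As before, dummy_input, X, Y, and num_epochs are placeholders left undefined in the guide. Because apply is vmapped over single examples here, dummy_input should be one unbatched example. One possible setup (all shapes are assumptions):

import jax.numpy as jnp

num_epochs = 10               # assumed number of training iterations
dummy_input = jnp.ones((5,))  # one unbatched example with 5 features (assumed)
X = jnp.ones((32, 5))         # assumed training batch of 32 inputs
Y = jnp.zeros((32, 1))        # assumed targets matching out_size=1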