批次规范化#

在本指南中,您将学习如何使用 flax.linen.BatchNorm 应用 批次规范化

批次规范化是一种正则化技术,用于加速训练并提高收敛速度。在训练期间,它计算特征维度的运行平均值。这添加了一种新的不可微分状态形式,必须适当地处理。

在本指南中,您将能够比较有无 Flax BatchNorm 的代码示例。

使用 BatchNorm 定义模型#

在 Flax 中,BatchNorm 是一个 flax.linen.Module,它在训练和推理之间表现出不同的运行时行为。您可以通过 use_running_average 参数显式指定它,如下所示。

一种常见模式是在父 Flax Module 中接受一个 traintraining)参数,并使用它来定义 BatchNormuse_running_average 参数。

注意:在其他机器学习框架(如 PyTorch 或 TensorFlow(Keras))中,这通过可变状态或调用标志来指定(例如,在 torch.nn.Module.evaltf.keras.Model 中,通过设置 training 标志)。

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=4)(x)

    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x
class MLP(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(features=4)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x

创建模型后,通过调用 flax.linen.init() 来初始化模型,以获取 variables 结构。这里,无 BatchNorm 和有 BatchNorm 的代码之间的主要区别在于必须提供 train 参数。

batch_stats 集合#

除了 params 集合之外,BatchNorm 还添加了一个 batch_stats 集合,其中包含批次统计信息的运行平均值。

注意:您可以在 flax.linen variables API 文档中了解有关此的更多信息。

必须从 variables 中提取 batch_stats 集合以供日后使用。

mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x)
params = variables['params']


jax.tree_util.tree_map(jnp.shape, variables)
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x, train=False)
params = variables['params']
batch_stats = variables['batch_stats']

jax.tree_util.tree_map(jnp.shape, variables)

Flax BatchNorm 共添加了 4 个变量:meanvar 位于 batch_stats 集合中,而 scalebias 位于 params 集合中。

FrozenDict({
  'params': {
    'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
    },
    'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
    },
  },
})
FrozenDict({
  'batch_stats': {
    'BatchNorm_0': {
        'mean': (4,),
        'var': (4,),
    },
  },
  'params': {
    'BatchNorm_0': {
        'bias': (4,),
        'scale': (4,),
    },
    'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
    },
    'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
    },
  },
})

修改 flax.linen.apply#

当使用 flax.linen.apply 使用 train==True 参数运行模型时(即,在调用 BatchNorm 时,use_running_average==False),您需要考虑以下内容

  • batch_stats 必须作为输入变量传递。

  • batch_stats 集合需要通过设置 mutable=['batch_stats'] 来标记为可变。

  • 修改后的变量作为第二个输出返回。必须从此处提取更新的 batch_stats

y = mlp.apply(
  {'params': params},
  x,
)
...
y, updates = mlp.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  train=True, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']

训练和评估#

在将使用 BatchNorm 的模型集成到训练循环中时,主要挑战是处理额外的 batch_stats 状态。为此,您需要

from flax.training import train_state


state = train_state.TrainState.create(
  apply_fn=mlp.apply,
  params=params,

  tx=optax.adam(1e-3),
)
from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=mlp.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3),
)

此外,更新您的 train_step 函数以反映这些更改

  • 将所有新参数传递给 flax.linen.apply(如前所述)。

  • batch_statsupdates 必须从 loss_fn 中传播出去。

  • 必须更新来自 TrainStatebatch_stats

@jax.jit
def train_step(state: train_state.TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)

  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
@jax.jit
def train_step(state: TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': state.batch_stats},
      x=batch['image'], train=True, mutable=['batch_stats'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

eval_step 非常简单。由于 batch_stats 不可变,因此无需传播任何更新。确保将 batch_stats 传递给 flax.linen.apply,并将 train 参数设置为 False

@jax.jit
def eval_step(state: train_state.TrainState, batch):
  """Train for a single step."""
  logits = state.apply_fn(
    {'params': params},
    x=batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics
@jax.jit
def eval_step(state: TrainState, batch):
  """Evaluate for a single step."""
  logits = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=batch['image'], train=False)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics