批次规范化#
在本指南中,您将学习如何使用 flax.linen.BatchNorm
应用 批次规范化。
批次规范化是一种正则化技术,用于加速训练并提高收敛速度。在训练期间,它计算特征维度的运行平均值。这添加了一种新的不可微分状态形式,必须适当地处理。
在本指南中,您将能够比较有无 Flax BatchNorm
的代码示例。
使用 BatchNorm
定义模型#
在 Flax 中,BatchNorm
是一个 flax.linen.Module
,它在训练和推理之间表现出不同的运行时行为。您可以通过 use_running_average
参数显式指定它,如下所示。
一种常见模式是在父 Flax Module
中接受一个 train
(training
)参数,并使用它来定义 BatchNorm
的 use_running_average
参数。
注意:在其他机器学习框架(如 PyTorch 或 TensorFlow(Keras))中,这通过可变状态或调用标志来指定(例如,在 torch.nn.Module.eval 或 tf.keras.Model
中,通过设置 training 标志)。
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=4)(x)
x = nn.relu(x)
x = nn.Dense(features=1)(x)
return x
class MLP(nn.Module):
@nn.compact
def __call__(self, x, train: bool):
x = nn.Dense(features=4)(x)
x = nn.BatchNorm(use_running_average=not train)(x)
x = nn.relu(x)
x = nn.Dense(features=1)(x)
return x
创建模型后,通过调用 flax.linen.init()
来初始化模型,以获取 variables
结构。这里,无 BatchNorm
和有 BatchNorm
的代码之间的主要区别在于必须提供 train
参数。
batch_stats
集合#
除了 params
集合之外,BatchNorm
还添加了一个 batch_stats
集合,其中包含批次统计信息的运行平均值。
注意:您可以在 flax.linen
variables API 文档中了解有关此的更多信息。
必须从 variables
中提取 batch_stats
集合以供日后使用。
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x)
params = variables['params']
jax.tree_util.tree_map(jnp.shape, variables)
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x, train=False)
params = variables['params']
batch_stats = variables['batch_stats']
jax.tree_util.tree_map(jnp.shape, variables)
Flax BatchNorm
共添加了 4 个变量:mean
和 var
位于 batch_stats
集合中,而 scale
和 bias
位于 params
集合中。
FrozenDict({
'params': {
'Dense_0': {
'bias': (4,),
'kernel': (3, 4),
},
'Dense_1': {
'bias': (1,),
'kernel': (4, 1),
},
},
})
FrozenDict({
'batch_stats': {
'BatchNorm_0': {
'mean': (4,),
'var': (4,),
},
},
'params': {
'BatchNorm_0': {
'bias': (4,),
'scale': (4,),
},
'Dense_0': {
'bias': (4,),
'kernel': (3, 4),
},
'Dense_1': {
'bias': (1,),
'kernel': (4, 1),
},
},
})
修改 flax.linen.apply
#
当使用 flax.linen.apply
使用 train==True
参数运行模型时(即,在调用 BatchNorm
时,use_running_average==False
),您需要考虑以下内容
batch_stats
必须作为输入变量传递。batch_stats
集合需要通过设置mutable=['batch_stats']
来标记为可变。修改后的变量作为第二个输出返回。必须从此处提取更新的
batch_stats
。
y = mlp.apply(
{'params': params},
x,
)
...
y, updates = mlp.apply(
{'params': params, 'batch_stats': batch_stats},
x,
train=True, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']
训练和评估#
在将使用 BatchNorm
的模型集成到训练循环中时,主要挑战是处理额外的 batch_stats
状态。为此,您需要
向自定义
flax.training.train_state.TrainState
类添加一个batch_stats
字段。将
batch_stats
值传递给train_state.TrainState.create
方法。
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=mlp.apply,
params=params,
tx=optax.adam(1e-3),
)
from flax.training import train_state
class TrainState(train_state.TrainState):
batch_stats: Any
state = TrainState.create(
apply_fn=mlp.apply,
params=params,
batch_stats=batch_stats,
tx=optax.adam(1e-3),
)
此外,更新您的 train_step
函数以反映这些更改
将所有新参数传递给
flax.linen.apply
(如前所述)。对
batch_stats
的updates
必须从loss_fn
中传播出去。必须更新来自
TrainState
的batch_stats
。
@jax.jit
def train_step(state: train_state.TrainState, batch):
"""Train for a single step."""
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
@jax.jit
def train_step(state: TrainState, batch):
"""Train for a single step."""
def loss_fn(params):
logits, updates = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
x=batch['image'], train=True, mutable=['batch_stats'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
return loss, (logits, updates)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, (logits, updates)), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=updates['batch_stats'])
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
eval_step
非常简单。由于 batch_stats
不可变,因此无需传播任何更新。确保将 batch_stats
传递给 flax.linen.apply
,并将 train
参数设置为 False
@jax.jit
def eval_step(state: train_state.TrainState, batch):
"""Train for a single step."""
logits = state.apply_fn(
{'params': params},
x=batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics
@jax.jit
def eval_step(state: TrainState, batch):
"""Evaluate for a single step."""
logits = state.apply_fn(
{'params': state.params, 'batch_stats': state.batch_stats},
x=batch['image'], train=False)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']).mean()
metrics = {
'loss': loss,
'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
}
return state, metrics