丢弃#
本指南概述了如何使用 丢弃 应用 flax.linen.Dropout()
.
丢弃是一种随机正则化技术,它随机删除网络中的隐藏和可见单元。
在本指南中,您可以比较有无 Flax Dropout
的代码示例。
拆分 PRNG 密钥#
由于丢弃是一种随机操作,因此它需要一个伪随机数生成器 (PRNG) 状态。Flax 使用 JAX 的(可拆分的)PRNG 密钥,它具有神经网络的许多理想属性。要了解更多信息,请参阅 JAX 中的伪随机数教程.
注意:请记住,JAX 有一种明确的方式来为您提供 PRNG 密钥:您可以将主 PRNG 状态(例如 key = jax.random.key(seed=0)
)使用 key, subkey = jax.random.split(key)
拆分为多个新的 PRNG 密钥。您可以在 🔪 JAX - 核心部分 🔪 随机性和 PRNG 密钥 中复习。
首先使用 jax.random.split() 将 PRNG 密钥拆分为三个密钥,包括一个用于 Flax Linen Dropout
的密钥。
root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
注意:在 Flax 中,您提供带有名称的PRNG 流,以便您可以在以后的 flax.linen.Module()
中使用它们。例如,您传递流 'params'
用于初始化参数,以及 'dropout'
用于应用 flax.linen.Dropout()
.
使用 Dropout
定义您的模型#
要创建带有丢弃的模型
子类化
flax.linen.Module()
,然后使用flax.linen.Dropout()
添加一个丢弃层。请记住,flax.linen.Module()
是 所有神经网络模块的基类,所有层和模型都从它派生子类。在
flax.linen.Dropout()
中,deterministic
参数必须作为关键字参数传递,具体方式如下:在构造
flax.linen.Module()
时;或在构造的
Module
上调用flax.linen.init()
或flax.linen.apply()
时。(有关更多详细信息,请参阅flax.linen.module.merge_param()
。)
因为
deterministic
是一个布尔值如果它设置为
False
,则输入将被屏蔽(即设置为零),其概率由rate
设置。并且其余输入将按1 / (1 - rate)
比例进行缩放,这确保了输入的均值得到保留。如果它设置为
True
,则不会应用任何遮罩(丢弃将被关闭),并且输入将按原样返回。
一个常见的模式是在父 Flax Module
中接受一个 training
(或 train
)参数(一个布尔值),并使用它来启用或禁用丢弃(如本指南后面部分所述)。在其他机器学习框架(如 PyTorch 或 TensorFlow(Keras))中,这是通过可变状态或调用标志指定的(例如,在 torch.nn.Module.eval 或 tf.keras.Model
中通过设置 training 标志)。
注意:Flax 提供了一种隐式的方式通过 Flax flax.linen.Module()
的 flax.linen.Module.make_rng()
方法来处理 PRNG 密钥流。这使您能够从 PRNG 流中在 Flax 模块(或其子模块)内部拆分一个新的 PRNG 密钥。 make_rng
方法保证每次调用时都提供一个唯一的密钥。在内部,flax.linen.Dropout()
使用 flax.linen.Module.make_rng()
为丢弃创建一个密钥。您可以查看 源代码。简而言之,flax.linen.Module.make_rng()
*保证完全可重复性*。
class MyModel(nn.Module):
num_neurons: int
@nn.compact
def __call__(self, x):
x = nn.Dense(self.num_neurons)(x)
return x
class MyModel(nn.Module):
num_neurons: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.num_neurons)(x)
# Set the dropout layer with a `rate` of 50%.
# When the `deterministic` flag is `True`, dropout is turned off.
x = nn.Dropout(rate=0.5, deterministic=not training)(x)
return x
初始化模型#
在创建模型后
实例化模型。
然后,在
flax.linen.init()
调用中,设置training=False
。最后,从 变量字典 中提取
params
。
在这里,有无 Flax Dropout
的代码之间的主要区别在于,如果您需要启用丢弃,则必须提供 training
(或 train
)参数。
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
variables = my_model.init(params_key, x)
params = variables['params']
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']
在训练期间执行前向传递#
当使用 flax.linen.apply()
运行您的模型时
将
training=True
传递给flax.linen.apply()
.然后,在正向传播期间(使用 dropout)绘制 PRNG 密钥时,在调用
flax.linen.apply()
时,提供一个 PRNG 密钥来为'dropout'
流播种。
# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})
这里,没有 Flax Dropout
和有 Dropout
的代码之间的主要区别在于,如果您需要启用 dropout,则必须提供 training
(或 train
)和 rngs
参数。
在评估期间,使用上面没有启用 dropout 的代码(这意味着您也不必传递 RNG)。
TrainState
和训练步骤#
本节说明如果启用了 dropout,如何在训练步骤函数中修改代码。
注意:请记住,Flax 有一种常见的模式,您可以在其中创建一个数据类来表示整个训练状态,包括参数和优化器状态。然后,您可以将单个参数 state: TrainState
传递给训练步骤函数。请参考 flax.training.train_state.TrainState()
API 文档以了解更多信息。
首先,将一个
key
字段添加到自定义flax.training.train_state.TrainState()
类中。然后,将
key
值(在本例中为dropout_key
)传递给train_state.TrainState.create()
方法。
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=my_model.apply,
params=params,
tx=optax.adam(1e-3)
)
from flax.training import train_state
class TrainState(train_state.TrainState):
key: jax.Array
state = TrainState.create(
apply_fn=my_model.apply,
params=params,
key=dropout_key,
tx=optax.adam(1e-3)
)
接下来,在 Flax 训练步骤函数
train_step
中,从dropout_key
生成一个新的 PRNG 密钥,以便在每一步应用 dropout。这可以通过以下方法之一完成使用
jax.random.fold_in()
通常更快。当您使用jax.random.split()
时,您会拆分一个可以重复使用的 PRNG 密钥。但是,使用jax.random.fold_in()
确保:1) 融合唯一数据;以及 2) 可能导致更长的 PRNG 流序列。最后,在执行正向传播时,将新的 PRNG 密钥作为额外参数传递给
state.apply_fn()
。
@jax.jit
def train_step(state: train_state.TrainState, batch):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'],
)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'],
training=True,
rngs={'dropout': dropout_train_key}
)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
带有 dropout 的 Flax 示例#
在 WMT 机器翻译数据集上训练的 基于 Transformer 的模型。此示例使用 dropout 和注意力 dropout。
在 文本分类 上下文中将词 dropout 应用于一批输入 ID。此示例使用自定义
flax.linen.Dropout()
层。