丢弃#

本指南概述了如何使用 丢弃 应用 flax.linen.Dropout().

丢弃是一种随机正则化技术,它随机删除网络中的隐藏和可见单元。

在本指南中,您可以比较有无 Flax Dropout 的代码示例。

拆分 PRNG 密钥#

由于丢弃是一种随机操作,因此它需要一个伪随机数生成器 (PRNG) 状态。Flax 使用 JAX 的(可拆分的)PRNG 密钥,它具有神经网络的许多理想属性。要了解更多信息,请参阅 JAX 中的伪随机数教程.

注意:请记住,JAX 有一种明确的方式来为您提供 PRNG 密钥:您可以将主 PRNG 状态(例如 key = jax.random.key(seed=0))使用 key, subkey = jax.random.split(key) 拆分为多个新的 PRNG 密钥。您可以在 🔪 JAX - 核心部分 🔪 随机性和 PRNG 密钥 中复习。

首先使用 jax.random.split() 将 PRNG 密钥拆分为三个密钥,包括一个用于 Flax Linen Dropout 的密钥。

root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

注意:在 Flax 中,您提供带有名称PRNG 流,以便您可以在以后的 flax.linen.Module() 中使用它们。例如,您传递流 'params' 用于初始化参数,以及 'dropout' 用于应用 flax.linen.Dropout().

使用 Dropout 定义您的模型#

要创建带有丢弃的模型

  • 子类化 flax.linen.Module(),然后使用 flax.linen.Dropout() 添加一个丢弃层。请记住,flax.linen.Module()所有神经网络模块的基类,所有层和模型都从它派生子类。

  • flax.linen.Dropout() 中,deterministic 参数必须作为关键字参数传递,具体方式如下:

  • 因为 deterministic 是一个布尔值

    • 如果它设置为 False,则输入将被屏蔽(即设置为零),其概率由 rate 设置。并且其余输入将按 1 / (1 - rate) 比例进行缩放,这确保了输入的均值得到保留。

    • 如果它设置为 True,则不会应用任何遮罩(丢弃将被关闭),并且输入将按原样返回。

一个常见的模式是在父 Flax Module 中接受一个 training(或 train)参数(一个布尔值),并使用它来启用或禁用丢弃(如本指南后面部分所述)。在其他机器学习框架(如 PyTorch 或 TensorFlow(Keras))中,这是通过可变状态或调用标志指定的(例如,在 torch.nn.Module.evaltf.keras.Model 中通过设置 training 标志)。

注意:Flax 提供了一种隐式的方式通过 Flax flax.linen.Module()flax.linen.Module.make_rng() 方法来处理 PRNG 密钥流。这使您能够从 PRNG 流中在 Flax 模块(或其子模块)内部拆分一个新的 PRNG 密钥。 make_rng 方法保证每次调用时都提供一个唯一的密钥。在内部,flax.linen.Dropout() 使用 flax.linen.Module.make_rng() 为丢弃创建一个密钥。您可以查看 源代码。简而言之,flax.linen.Module.make_rng() *保证完全可重复性*。

class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)

    return x
class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a `rate` of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x

初始化模型#

在创建模型后

在这里,有无 Flax Dropout 的代码之间的主要区别在于,如果您需要启用丢弃,则必须提供 training(或 train)参数。

my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))

variables = my_model.init(params_key, x)
params = variables['params']
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']

在训练期间执行前向传递#

当使用 flax.linen.apply() 运行您的模型时

  • training=True 传递给 flax.linen.apply().

  • 然后,在正向传播期间(使用 dropout)绘制 PRNG 密钥时,在调用 flax.linen.apply() 时,提供一个 PRNG 密钥来为 'dropout' 流播种。

# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})

这里,没有 Flax Dropout 和有 Dropout 的代码之间的主要区别在于,如果您需要启用 dropout,则必须提供 training(或 train)和 rngs 参数。

在评估期间,使用上面没有启用 dropout 的代码(这意味着您也不必传递 RNG)。

TrainState 和训练步骤#

本节说明如果启用了 dropout,如何在训练步骤函数中修改代码。

注意:请记住,Flax 有一种常见的模式,您可以在其中创建一个数据类来表示整个训练状态,包括参数和优化器状态。然后,您可以将单个参数 state: TrainState 传递给训练步骤函数。请参考 flax.training.train_state.TrainState() API 文档以了解更多信息。

  • 首先,将一个 key 字段添加到自定义 flax.training.train_state.TrainState() 类中。

  • 然后,将 key 值(在本例中为 dropout_key)传递给 train_state.TrainState.create() 方法。

from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=my_model.apply,
  params=params,

  tx=optax.adam(1e-3)
)
from flax.training import train_state

class TrainState(train_state.TrainState):
  key: jax.Array

state = TrainState.create(
  apply_fn=my_model.apply,
  params=params,
  key=dropout_key,
  tx=optax.adam(1e-3)
)

  • 接下来,在 Flax 训练步骤函数 train_step 中,从 dropout_key 生成一个新的 PRNG 密钥,以便在每一步应用 dropout。这可以通过以下方法之一完成

    使用 jax.random.fold_in() 通常更快。当您使用 jax.random.split() 时,您会拆分一个可以重复使用的 PRNG 密钥。但是,使用 jax.random.fold_in() 确保:1) 融合唯一数据;以及 2) 可能导致更长的 PRNG 流序列。

  • 最后,在执行正向传播时,将新的 PRNG 密钥作为额外参数传递给 state.apply_fn()

@jax.jit
def train_step(state: train_state.TrainState, batch):

  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],


      )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
  dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],
      training=True,
      rngs={'dropout': dropout_train_key}
      )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

带有 dropout 的 Flax 示例#

更多使用 Module make_rng() 的 Flax 示例#