🔪 Flax - 锋利的细节 🔪#

Flax 公开了 JAX 的全部功能。就像使用 JAX 时一样,在使用 Flax 时也可能遇到某些“锋利的细节”。这份不断发展的文档旨在帮助您解决这些问题。

首先,安装和/或更新 Flax

! pip install -qq flax

🔪 flax.linen.Dropout 层和随机性#

太长不看#

在处理具有 dropout 的模型(从 Flax Module 继承)时,仅在正向传递期间添加 'dropout' PRNG 密钥。

  1. jax.random.split() 开始,显式创建用于 'params''dropout' 的 PRNG 密钥。

  2. flax.linen.Dropout 层添加到您的模型中(从 Flax Module 继承)。

  3. 在初始化模型(flax.linen.init())时,无需传入额外的 'dropout' PRNG 密钥——只需像“更简单”的模型一样传入 'params' 密钥。

  4. 在使用 flax.linen.apply() 进行正向传递时,传入 rngs={'dropout': dropout_key}

查看下面的完整示例。

为什么这样有效#

  • 在内部,flax.linen.Dropout 使用 flax.linen.Module.make_rng 为 dropout 创建一个密钥(查看 源代码)。

  • 每次调用 make_rng 时(在本例中,它在 Dropout 中隐式完成),您都会从主/根 PRNG 密钥中获得一个新的 PRNG 密钥拆分。

  • make_rng 仍然保证完全可重复

背景#

随机正则化技术 dropout 会随机删除网络中的隐藏和可见单元。Dropout 是一种随机操作,需要 PRNG 状态,而 Flax(像 JAX 一样)使用 Threefry PRNG,它是可拆分的。

注意:请记住,JAX 有一种显式的方式为您提供 PRNG 密钥:您可以使用 key = jax.random.key(seed=0) 将主 PRNG 状态分解成多个新的 PRNG 密钥,使用 key, subkey = jax.random.split(key)。请在 🔪 JAX - 锋利的细节 🔪 随机性和 PRNG 密钥 中回顾一下。

Flax 通过 Flax Moduleflax.linen.Module.make_rng 辅助函数提供了一种隐式的处理 PRNG 密钥流的方式。它允许 Flax Module(或其子 Module)中的代码“提取 PRNG 密钥”。make_rng 保证每次调用时都提供一个唯一的密钥。有关更多详细信息,请参阅 RNG 指南

注意:请记住,flax.linen.Module 是所有神经网络模块的基类。所有层和模型都从它继承。

示例#

请记住,每个 Flax PRNG 流都有一个名称。下面的示例使用 'params' 流来初始化参数,以及 'dropout' 流。提供给 flax.linen.init() 的 PRNG 密钥是为 'params' PRNG 密钥流设置种子的密钥。要在正向传递(使用 dropout)期间提取 PRNG 密钥,请在调用 Module.apply() 时提供一个 PRNG 密钥来为该流('dropout')设置种子。

# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn
# Randomness.
seed = 0
root_key = jax.random.key(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50% .
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)

x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.) 
variables = my_model.init(params_key, x)

# Perform the forward pass with `flax.linen.apply()`.
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})

现实生活中的例子