🔪 Flax - 锋利的细节 🔪#
Flax 公开了 JAX 的全部功能。就像使用 JAX 时一样,在使用 Flax 时也可能遇到某些“锋利的细节”。这份不断发展的文档旨在帮助您解决这些问题。
首先,安装和/或更新 Flax
! pip install -qq flax
🔪 flax.linen.Dropout
层和随机性#
太长不看#
在处理具有 dropout 的模型(从 Flax Module
继承)时,仅在正向传递期间添加 'dropout'
PRNG 密钥。
从
jax.random.split()
开始,显式创建用于'params'
和'dropout'
的 PRNG 密钥。将
flax.linen.Dropout
层添加到您的模型中(从 FlaxModule
继承)。在初始化模型(
flax.linen.init()
)时,无需传入额外的'dropout'
PRNG 密钥——只需像“更简单”的模型一样传入'params'
密钥。在使用
flax.linen.apply()
进行正向传递时,传入rngs={'dropout': dropout_key}
。
查看下面的完整示例。
为什么这样有效#
在内部,
flax.linen.Dropout
使用flax.linen.Module.make_rng
为 dropout 创建一个密钥(查看 源代码)。每次调用
make_rng
时(在本例中,它在Dropout
中隐式完成),您都会从主/根 PRNG 密钥中获得一个新的 PRNG 密钥拆分。make_rng
仍然保证完全可重复。
背景#
随机正则化技术 dropout 会随机删除网络中的隐藏和可见单元。Dropout 是一种随机操作,需要 PRNG 状态,而 Flax(像 JAX 一样)使用 Threefry PRNG,它是可拆分的。
注意:请记住,JAX 有一种显式的方式为您提供 PRNG 密钥:您可以使用
key = jax.random.key(seed=0)
将主 PRNG 状态分解成多个新的 PRNG 密钥,使用key, subkey = jax.random.split(key)
。请在 🔪 JAX - 锋利的细节 🔪 随机性和 PRNG 密钥 中回顾一下。
Flax 通过 Flax Module
的 flax.linen.Module.make_rng
辅助函数提供了一种隐式的处理 PRNG 密钥流的方式。它允许 Flax Module
(或其子 Module
)中的代码“提取 PRNG 密钥”。make_rng
保证每次调用时都提供一个唯一的密钥。有关更多详细信息,请参阅 RNG 指南。
注意:请记住,
flax.linen.Module
是所有神经网络模块的基类。所有层和模型都从它继承。
示例#
请记住,每个 Flax PRNG 流都有一个名称。下面的示例使用 'params'
流来初始化参数,以及 'dropout'
流。提供给 flax.linen.init()
的 PRNG 密钥是为 'params'
PRNG 密钥流设置种子的密钥。要在正向传递(使用 dropout)期间提取 PRNG 密钥,请在调用 Module.apply()
时提供一个 PRNG 密钥来为该流('dropout'
)设置种子。
# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn
# Randomness.
seed = 0
root_key = jax.random.key(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
# A simple network.
class MyModel(nn.Module):
num_neurons: int
training: bool
@nn.compact
def __call__(self, x):
x = nn.Dense(self.num_neurons)(x)
# Set the dropout layer with a rate of 50% .
# When the `deterministic` flag is `True`, dropout is turned off.
x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
return x
# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)
x = jax.random.uniform(key=main_key, shape=(3, 4, 4))
# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.)
variables = my_model.init(params_key, x)
# Perform the forward pass with `flax.linen.apply()`.
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})
现实生活中的例子