随机性#
与 Haiku 和 Flax Linen 等系统相比,Flax NNX 中的随机状态处理得到了根本简化,因为 Flax NNX *将随机状态定义为对象状态*。本质上,这意味着在 Flax NNX 中,随机状态:1) 只是另一种类型的状态;2) 存储在 nnx.Variable
中;3) 由模型本身持有。
Flax NNX 伪随机数生成器 (PRNG) 系统具有以下主要特征
它是显式的。
它是基于顺序的。
它使用动态计数器。
这与 Flax Linen 的 PRNG 系统 有些不同,后者是基于 (路径 + 顺序)
的,并使用静态计数器。
注意:要了解有关 JAX 中随机数生成、
jax.random
API 和 PRNG 生成序列的更多信息,请查看此 JAX PRNG 教程。
让我们从一些必要的导入开始
from flax import nnx
import jax
from jax import random, numpy as jnp
Rngs
、RngStream
和 RngState
#
在 Flax NNX 中,nnx.Rngs
类型是管理随机状态的主要便捷 API。 遵循 Flax Linen 的脚步,nnx.Rngs
能够创建多个命名的 PRNG 密钥 流,每个流都有自己的状态,以便在 JAX 转换(transforms)的上下文中对随机性进行严格控制。
以下是 Flax NNX 中主要的 PRNG 相关类型
nnx.Rngs
:主要用户界面。它定义了一组命名的nnx.RngStream
对象。nnx.RngStream
:一个可以生成 PRNG 密钥流的对象。它在一个nnx.RngKey
和nnx.RngCount
nnx.Variable
中分别保存一个根key
和一个count
。当生成新密钥时,计数会递增。nnx.RngState
:所有 RNG 相关状态的基本类型。nnx.RngKey
:用于保存 PRNG 密钥的 NNX 变量类型。它包含一个tag
属性,其中包含 PRNG 密钥流的名称。nnx.RngCount
:用于保存 PRNG 计数的 NNX 变量类型。它包含一个tag
属性,其中包含 PRNG 密钥流名称。
要创建 nnx.Rngs
对象,您可以简单地将整数种子或 jax.random.key
实例传递给构造函数中您选择的任何关键字参数。
这是一个例子
rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)
请注意,key
和 count
nnx.Variable
在 tag
属性中包含 PRNG 密钥流名称。这主要用于过滤,我们稍后会看到。
要生成新密钥,您可以访问其中一个流,并使用其 __call__
方法,不带任何参数。这将通过使用 random.fold_in
和当前 key
和 count
返回一个新密钥。然后,count
会递增,以便后续调用将返回新密钥。
params_key = rngs.params()
dropout_key = rngs.dropout()
nnx.display(rngs)
请注意,当生成新的 PRNG 密钥时,key
属性不会更改。
标准 PRNG 密钥流名称#
Flax NNX 的内置层仅使用两个标准 PRNG 密钥流名称,如下表所示
PRNG 密钥流名称 |
描述 |
---|---|
|
用于参数初始化 |
|
|
params
由大多数标准层(如nnx.Linear
、nnx.Conv
、nnx.MultiHeadAttention
等)在构造过程中使用,以初始化其参数。dropout
由nnx.Dropout
和nnx.MultiHeadAttention
用于生成 dropout 掩码。
下面是一个简单示例,演示了一个使用 params
和 dropout
PRNG 密钥流的模型
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.drop(self.linear(x)))
model = Model(nnx.Rngs(params=0, dropout=1))
y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)
默认 PRNG 密钥流#
拥有命名流的一个缺点是,用户在创建 nnx.Rngs
对象时需要知道模型将使用的所有可能名称。虽然可以通过一些文档来解决这个问题,但 Flax NNX 提供了一个 default
流,当找不到流时可以用作回退。要使用默认 PRNG 密钥流,您可以简单地将整数种子或 jax.random.key
作为第一个位置参数传递。
rngs = nnx.Rngs(0, params=1)
key1 = rngs.params() # Call params.
key2 = rngs.dropout() # Fallback to the default stream.
key3 = rngs() # Call the default stream directly.
# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))
nnx.display(rngs)
如上所示,还可以通过调用 nnx.Rngs
对象本身来生成来自 default
流的 PRNG 密钥。
注意
对于大型项目,建议使用命名流以避免潜在的冲突。对于小型项目或快速原型设计,仅使用default
流是一个不错的选择。
过滤随机状态#
可以使用过滤器来操作随机状态,就像操作任何其他类型的状态一样。可以使用类型(nnx.RngState
、nnx.RngKey
、nnx.RngCount
)或与流名称对应的字符串进行过滤(请参阅Flax NNX Filter
DSL)。以下是一个使用 nnx.state
和各种过滤器来选择 Model
内部 Rngs
的不同子状态的示例
model = Model(nnx.Rngs(params=0, dropout=1))
rng_state = nnx.state(model, nnx.RngState) # All random states.
key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
count_state = nnx.state(model, nnx.RngCount) # Only counts.
rng_params_state = nnx.state(model, 'params') # Only `params`.
rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.
nnx.display(params_key_state)
重置种子#
在 Haiku 和 Flax Linen 中,每次调用模型之前,都会显式地将随机状态传递给 Module.apply
。这使得在需要时(例如,为了可重复性)可以轻松控制模型的随机性。
在 Flax NNX 中,有两种方法可以实现这一点
通过手动通过
__call__
堆栈传递nnx.Rngs
对象。如果你想对随机状态进行严格控制,像nnx.Dropout
和nnx.MultiHeadAttention
这样的标准层会接受rngs
参数。通过使用
nnx.reseed
将模型的随机状态设置为特定的配置。此选项侵入性较小,即使模型并非旨在支持手动控制随机状态也可以使用。
nnx.reseed
是一个接受任意图节点(包括 pytrees 的 nnx.Module
)的函数,以及一些包含 nnx.RngStream
的新种子或键值的关键字参数(由参数名称指定)。nnx.reseed
将遍历图并更新匹配的 nnx.RngStream
的随机状态,这包括将 key
设置为可能的新值,并将 count
重置为零。
以下是如何使用 nnx.reseed
重置 nnx.Dropout
层的随机状态,并验证计算与第一次调用模型时相同的示例
model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))
y1 = model(x)
y2 = model(x)
nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)
assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3) # same
拆分 PRNG 键#
当与 Flax NNX 转换(如 nnx.vmap
或 nnx.pmap
)交互时,通常需要拆分随机状态,以便每个副本都有自己唯一的状态。这可以通过两种方式完成
在将键传递给
nnx.Rngs
流之一之前,手动拆分键;或通过使用
nnx.split_rngs
装饰器,它会自动拆分函数输入中找到的任何nnx.RngStream
的随机状态,并在函数调用结束后自动“降低”它们。
使用 nnx.split_rngs
更方便,因为它与 Flax NNX 转换配合良好,因此这里有一个示例
rngs = nnx.Rngs(params=0, dropout=1)
@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
print('Inside:')
# rngs.dropout() # ValueError: fold_in accepts a single key...
nnx.display(rngs)
f(rngs)
print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:
注意:
nnx.split_rngs
允许将 NNXFilter
传递给only
关键字参数,以便选择函数内部应拆分的nnx.RngStream
。在这种情况下,你只需要拆分dropout
PRNG 键流。
转换#
如前所述,在 Flax NNX 中,随机状态只是另一种类型的状态。这意味着,当涉及到 Flax NNX 转换时,它没有什么特殊之处,这意味着你应该能够使用每个转换的 Flax NNX 状态处理 API 来获得你想要的结果。
在本节中,你将通过两个在 Flax NNX 转换中使用随机状态的示例 - 一个使用 nnx.pmap
,你将学习如何拆分 PRNG 状态,另一个使用 nnx.scan
,你将冻结 PRNG 状态。
数据并行 dropout#
在第一个示例中,你将探索如何使用 nnx.pmap
在数据并行上下文中调用 nnx.Model
。
由于
nnx.Model
使用nnx.Dropout
,你需要拆分dropout
的随机状态,以确保每个副本获得不同的 dropout 掩码。nnx.StateAxes
被传递给in_axes
,以指定model
的dropout
PRNG 键流将在轴0
上并行化,其余状态将被复制。nnx.split_rngs
用于将dropout
PRNG 键流的键拆分为 N 个唯一的键,每个副本一个。
model = Model(nnx.Rngs(params=0, dropout=1))
num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})
@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
return model(x)
y = forward(model, x)
print(y.shape)
(1, 16, 10)
循环 dropout#
接下来,让我们探讨如何实现一个使用循环 dropout 的 RNNCell
。为此
首先,你将创建一个
nnx.Dropout
层,该层将从自定义的recurrent_dropout
流中采样 PRNG 键。你将对
RNNCell
的隐藏状态h
应用 dropout (drop
)。然后,定义一个
initial_state
函数来创建RNNCell
的初始状态。最后,实例化
RNNCell
。
class Count(nnx.Variable): pass
class RNNCell(nnx.Module):
def __init__(self, din, dout, rngs):
self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
self.dout = dout
self.count = Count(jnp.array(0, jnp.uint32))
def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
h = self.drop(h) # Recurrent dropout.
y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
self.count += 1
return y, y
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.dout))
cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))
接下来,你将使用 nnx.scan
对 unroll
函数进行扫描,以实现 rnn_forward
操作
循环 dropout 的关键在于在所有时间步上应用相同的 dropout 掩码。因此,为了实现这一点,你将
nnx.StateAxes
传递给nnx.scan
的in_axes
,指定将广播cell
的recurrent_dropout
PRNG 流,并且其余的RNNCell
状态将被传递。此外,隐藏状态
h
将是nnx.scan
的Carry
变量,并且序列x
将在其轴1
上进行scan
操作。
@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
h = cell.initial_state(batch_size=x.shape[0])
# Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
@nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
h, y = cell(h, x)
return h, y
h, y = unroll(cell, h, x)
return y
x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)
print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)