随机性#

与 Haiku 和 Flax Linen 等系统相比,Flax NNX 中的随机状态处理得到了根本简化,因为 Flax NNX *将随机状态定义为对象状态*。本质上,这意味着在 Flax NNX 中,随机状态:1) 只是另一种类型的状态;2) 存储在 nnx.Variable 中;3) 由模型本身持有。

Flax NNX 伪随机数生成器 (PRNG) 系统具有以下主要特征

  • 它是显式的

  • 它是基于顺序的

  • 它使用动态计数器

这与 Flax Linen 的 PRNG 系统 有些不同,后者是基于 (路径 + 顺序) 的,并使用静态计数器。

注意:要了解有关 JAX 中随机数生成、jax.random API 和 PRNG 生成序列的更多信息,请查看此 JAX PRNG 教程

让我们从一些必要的导入开始

from flax import nnx
import jax
from jax import random, numpy as jnp

RngsRngStreamRngState#

在 Flax NNX 中,nnx.Rngs 类型是管理随机状态的主要便捷 API。 遵循 Flax Linen 的脚步,nnx.Rngs 能够创建多个命名的 PRNG 密钥 ,每个流都有自己的状态,以便在 JAX 转换(transforms)的上下文中对随机性进行严格控制。

以下是 Flax NNX 中主要的 PRNG 相关类型

  • nnx.Rngs:主要用户界面。它定义了一组命名的 nnx.RngStream 对象。

  • nnx.RngStream:一个可以生成 PRNG 密钥流的对象。它在一个 nnx.RngKeynnx.RngCount nnx.Variable 中分别保存一个根 key 和一个 count。当生成新密钥时,计数会递增。

  • nnx.RngState:所有 RNG 相关状态的基本类型。

    • nnx.RngKey:用于保存 PRNG 密钥的 NNX 变量类型。它包含一个 tag 属性,其中包含 PRNG 密钥流的名称。

    • nnx.RngCount:用于保存 PRNG 计数的 NNX 变量类型。它包含一个 tag 属性,其中包含 PRNG 密钥流名称。

要创建 nnx.Rngs 对象,您可以简单地将整数种子或 jax.random.key 实例传递给构造函数中您选择的任何关键字参数。

这是一个例子

rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)

请注意,keycount nnx.Variabletag 属性中包含 PRNG 密钥流名称。这主要用于过滤,我们稍后会看到。

要生成新密钥,您可以访问其中一个流,并使用其 __call__ 方法,不带任何参数。这将通过使用 random.fold_in 和当前 keycount 返回一个新密钥。然后,count 会递增,以便后续调用将返回新密钥。

params_key = rngs.params()
dropout_key = rngs.dropout()

nnx.display(rngs)

请注意,当生成新的 PRNG 密钥时,key 属性不会更改。

标准 PRNG 密钥流名称#

Flax NNX 的内置层仅使用两个标准 PRNG 密钥流名称,如下表所示

PRNG 密钥流名称

描述

参数

用于参数初始化

dropout

nnx.Dropout 用于创建 dropout 掩码

  • params 由大多数标准层(如 nnx.Linearnnx.Convnnx.MultiHeadAttention 等)在构造过程中使用,以初始化其参数。

  • dropoutnnx.Dropoutnnx.MultiHeadAttention 用于生成 dropout 掩码。

下面是一个简单示例,演示了一个使用 paramsdropout PRNG 密钥流的模型

class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.drop(self.linear(x)))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)

默认 PRNG 密钥流#

拥有命名流的一个缺点是,用户在创建 nnx.Rngs 对象时需要知道模型将使用的所有可能名称。虽然可以通过一些文档来解决这个问题,但 Flax NNX 提供了一个 default 流,当找不到流时可以用作回退。要使用默认 PRNG 密钥流,您可以简单地将整数种子或 jax.random.key 作为第一个位置参数传递。

rngs = nnx.Rngs(0, params=1)

key1 = rngs.params() # Call params.
key2 = rngs.dropout() # Fallback to the default stream.
key3 = rngs() # Call the default stream directly.

# Test with the `Model` that uses `params` and `dropout`.
model = Model(rngs)
y = model(jnp.ones((1, 20)))

nnx.display(rngs)

如上所示,还可以通过调用 nnx.Rngs 对象本身来生成来自 default 流的 PRNG 密钥。

注意
对于大型项目,建议使用命名流以避免潜在的冲突。对于小型项目或快速原型设计,仅使用 default 流是一个不错的选择。

过滤随机状态#

可以使用过滤器来操作随机状态,就像操作任何其他类型的状态一样。可以使用类型(nnx.RngStatennx.RngKeynnx.RngCount)或与流名称对应的字符串进行过滤(请参阅Flax NNX Filter DSL)。以下是一个使用 nnx.state 和各种过滤器来选择 Model 内部 Rngs 的不同子状态的示例

model = Model(nnx.Rngs(params=0, dropout=1))

rng_state = nnx.state(model, nnx.RngState) # All random states.
key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.
count_state = nnx.state(model, nnx.RngCount) # Only counts.
rng_params_state = nnx.state(model, 'params') # Only `params`.
rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # `Params` PRNG keys.

nnx.display(params_key_state)

重置种子#

在 Haiku 和 Flax Linen 中,每次调用模型之前,都会显式地将随机状态传递给 Module.apply。这使得在需要时(例如,为了可重复性)可以轻松控制模型的随机性。

在 Flax NNX 中,有两种方法可以实现这一点

  1. 通过手动通过 __call__ 堆栈传递 nnx.Rngs 对象。如果你想对随机状态进行严格控制,像 nnx.Dropoutnnx.MultiHeadAttention 这样的标准层会接受 rngs 参数。

  2. 通过使用 nnx.reseed 将模型的随机状态设置为特定的配置。此选项侵入性较小,即使模型并非旨在支持手动控制随机状态也可以使用。

nnx.reseed 是一个接受任意图节点(包括 pytreesnnx.Module)的函数,以及一些包含 nnx.RngStream 的新种子或键值的关键字参数(由参数名称指定)。nnx.reseed 将遍历图并更新匹配的 nnx.RngStream 的随机状态,这包括将 key 设置为可能的新值,并将 count 重置为零。

以下是如何使用 nnx.reseed 重置 nnx.Dropout 层的随机状态,并验证计算与第一次调用模型时相同的示例

model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))

y1 = model(x)
y2 = model(x)

nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)

assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3)     # same

拆分 PRNG 键#

当与 Flax NNX 转换(如 nnx.vmapnnx.pmap)交互时,通常需要拆分随机状态,以便每个副本都有自己唯一的状态。这可以通过两种方式完成

  • 在将键传递给 nnx.Rngs 流之一之前,手动拆分键;或

  • 通过使用 nnx.split_rngs 装饰器,它会自动拆分函数输入中找到的任何 nnx.RngStream 的随机状态,并在函数调用结束后自动“降低”它们。

使用 nnx.split_rngs 更方便,因为它与 Flax NNX 转换配合良好,因此这里有一个示例

rngs = nnx.Rngs(params=0, dropout=1)

@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
  print('Inside:')
  # rngs.dropout() # ValueError: fold_in accepts a single key...
  nnx.display(rngs)

f(rngs)

print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:

注意:nnx.split_rngs 允许将 NNX Filter 传递给 only 关键字参数,以便选择函数内部应拆分的 nnx.RngStream。在这种情况下,你只需要拆分 dropout PRNG 键流。

转换#

如前所述,在 Flax NNX 中,随机状态只是另一种类型的状态。这意味着,当涉及到 Flax NNX 转换时,它没有什么特殊之处,这意味着你应该能够使用每个转换的 Flax NNX 状态处理 API 来获得你想要的结果。

在本节中,你将通过两个在 Flax NNX 转换中使用随机状态的示例 - 一个使用 nnx.pmap,你将学习如何拆分 PRNG 状态,另一个使用 nnx.scan,你将冻结 PRNG 状态。

数据并行 dropout#

在第一个示例中,你将探索如何使用 nnx.pmap 在数据并行上下文中调用 nnx.Model

  • 由于 nnx.Model 使用 nnx.Dropout,你需要拆分 dropout 的随机状态,以确保每个副本获得不同的 dropout 掩码。

  • nnx.StateAxes 被传递给 in_axes,以指定 modeldropout PRNG 键流将在轴 0 上并行化,其余状态将被复制。

  • nnx.split_rngs 用于将 dropout PRNG 键流的键拆分为 N 个唯一的键,每个副本一个。

model = Model(nnx.Rngs(params=0, dropout=1))

num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
  return model(x)

y = forward(model, x)
print(y.shape)
(1, 16, 10)

循环 dropout#

接下来,让我们探讨如何实现一个使用循环 dropout 的 RNNCell。为此

  • 首先,你将创建一个 nnx.Dropout 层,该层将从自定义的 recurrent_dropout 流中采样 PRNG 键。

  • 你将对 RNNCell 的隐藏状态 h 应用 dropout (drop)。

  • 然后,定义一个 initial_state 函数来创建 RNNCell 的初始状态。

  • 最后,实例化 RNNCell

class Count(nnx.Variable): pass

class RNNCell(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
    self.dout = dout
    self.count = Count(jnp.array(0, jnp.uint32))

  def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
    h = self.drop(h) # Recurrent dropout.
    y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
    self.count += 1
    return y, y

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.dout))

cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))

接下来,你将使用 nnx.scanunroll 函数进行扫描,以实现 rnn_forward 操作

  • 循环 dropout 的关键在于在所有时间步上应用相同的 dropout 掩码。因此,为了实现这一点,你将 nnx.StateAxes 传递给 nnx.scanin_axes,指定将广播 cellrecurrent_dropout PRNG 流,并且其余的 RNNCell 状态将被传递。

  • 此外,隐藏状态 h 将是 nnx.scanCarry 变量,并且序列 x 将在其轴 1 上进行 scan 操作。

@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
  h = cell.initial_state(batch_size=x.shape[0])

  # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.
  state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
  @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
  def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
    h, y = cell(h, x)
    return h, y

  h, y = unroll(cell, h, x)
  return y

x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)

print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)