rnglib#

class flax.nnx.Rngs(*args, **kwargs)[source]#

NNX rng 容器类。要实例化 Rngs,请传入一个整数,指定起始种子。Rngs 可以有不同的“流”,允许用户生成不同的 rng 键。例如,要为 paramsdropout 流生成键

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() != rng2.dropout()

因为我们传入了 params=1,所以 params 的起始种子为 1,而 dropout 的起始种子默认为我们传入的 0,因为我们没有为 dropout 指定种子。如果我们没有为 params 指定种子,那么两个流都将默认为使用我们传入的 0

>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() == rng2.dropout()

Rngs 容器类为每个流包含一个单独的计数器。每次调用流以生成新的 rng 键时,计数器都会递增 1。要生成新的 rng 键,我们将当前 rng 流的计数器值折叠到其对应的起始种子中。如果我们尝试为实例化时未指定的流生成 rng 键,则使用 default 流(即,在实例化期间传递给 Rngs 的第一个位置参数是 default 起始种子)

>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` stream starting seed is defaulted to 100, counter is 0
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # empty stream starting seed is defaulted to 100, counter is 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)

让我们看一个在 Module 中使用 Rngs 的示例,并通过手动线程化 Rngs 来验证输出

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     # Linear uses the `params` stream twice for kernel and bias
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     # Dropout uses the `dropout` stream once
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))

>>> def assert_same(x, rng_seed, **rng_kwargs):
...   model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
...   out = model(x)
...
...   # manual forward propagation
...   rngs = nnx.Rngs(rng_seed, **rng_kwargs)
...   kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
...   assert (model.linear.kernel.value==kernel).all()
...   bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
...   assert (model.linear.bias.value==bias).all()
...   mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
...   # dropout scales the output proportional to the dropout rate
...   manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
...   assert (out == manual_out).all()

>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
__init__(default=None, /, **rngs)[source]#
参数
  • defaultdefault 流的起始种子。从 **rngs 关键字参数中未指定的流生成的任何键都将默认使用此起始种子。

  • **rngs – 可选的关键字参数,用于为不同的 rng 流指定起始种子。关键字是流名称,其值是该流对应的起始种子。

class flax.nnx.RngStream(*args: 'Any', **kwargs: 'Any')[source]#
flax.nnx.reseed(node, /, **stream_keys)[source]#

使用新密钥更新指定 RNG 流的密钥。

参数
  • node – 要在其中重新设置 RNG 流的节点。

  • **stream_keys – 流名称到新密钥的映射。密钥可以是整数或 jax 数组。如果传入整数,则将使用 jax.random.key 生成密钥。

引发

ValueError – 如果现有流密钥不是标量。

示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)