rnglib#
- class flax.nnx.Rngs(*args, **kwargs)[source]#
NNX rng 容器类。要实例化
Rngs
,请传入一个整数,指定起始种子。Rngs
可以有不同的“流”,允许用户生成不同的 rng 键。例如,要为params
和dropout
流生成键>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> rng1 = nnx.Rngs(0, params=1) >>> rng2 = nnx.Rngs(0) >>> assert rng1.params() != rng2.dropout()
因为我们传入了
params=1
,所以params
的起始种子为1
,而dropout
的起始种子默认为我们传入的0
,因为我们没有为dropout
指定种子。如果我们没有为params
指定种子,那么两个流都将默认为使用我们传入的0
>>> rng1 = nnx.Rngs(0) >>> rng2 = nnx.Rngs(0) >>> assert rng1.params() == rng2.dropout()
Rngs
容器类为每个流包含一个单独的计数器。每次调用流以生成新的 rng 键时,计数器都会递增1
。要生成新的 rng 键,我们将当前 rng 流的计数器值折叠到其对应的起始种子中。如果我们尝试为实例化时未指定的流生成 rng 键,则使用default
流(即,在实例化期间传递给Rngs
的第一个位置参数是default
起始种子)>>> rng1 = nnx.Rngs(100, params=42) >>> # `params` stream starting seed is 42, counter is 0 >>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0) >>> # `dropout` stream starting seed is defaulted to 100, counter is 0 >>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0) >>> # empty stream starting seed is defaulted to 100, counter is 1 >>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1) >>> # `params` stream starting seed is 42, counter is 1 >>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)
让我们看一个在
Module
中使用Rngs
的示例,并通过手动线程化Rngs
来验证输出>>> class Model(nnx.Module): ... def __init__(self, rngs): ... # Linear uses the `params` stream twice for kernel and bias ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... # Dropout uses the `dropout` stream once ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... return self.dropout(self.linear(x)) >>> def assert_same(x, rng_seed, **rng_kwargs): ... model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs)) ... out = model(x) ... ... # manual forward propagation ... rngs = nnx.Rngs(rng_seed, **rng_kwargs) ... kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3)) ... assert (model.linear.kernel.value==kernel).all() ... bias = nnx.initializers.zeros_init()(rngs.params(), (3,)) ... assert (model.linear.bias.value==bias).all() ... mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3)) ... # dropout scales the output proportional to the dropout rate ... manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5 ... assert (out == manual_out).all() >>> x = jnp.ones((1, 2)) >>> assert_same(x, 0) >>> assert_same(x, 0, params=1) >>> assert_same(x, 0, params=1, dropout=2)
- flax.nnx.reseed(node, /, **stream_keys)[source]#
使用新密钥更新指定 RNG 流的密钥。
- 参数
node – 要在其中重新设置 RNG 流的节点。
**stream_keys – 流名称到新密钥的映射。密钥可以是整数或 jax 数组。如果传入整数,则将使用
jax.random.key
生成密钥。
- 引发
ValueError – 如果现有流密钥不是标量。
示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... return self.dropout(self.linear(x)) ... >>> model = Model(nnx.Rngs(params=0, dropout=42)) >>> x = jnp.ones((1, 2)) ... >>> y1 = model(x) ... >>> # reset the ``dropout`` stream key to 42 >>> nnx.reseed(model, dropout=42) >>> y2 = model(x) ... >>> jnp.allclose(y1, y2) Array(True, dtype=bool)