随机性#

class flax.nnx.Dropout(*args, **kwargs)[源代码]#

创建一个 dropout 层。

要使用 dropout,请调用 train() 方法(或者在构造函数或调用时传入 deterministic=False)。

要禁用 dropout,请调用 eval() 方法(或者在构造函数或调用时传入 deterministic=True)。

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train() # use dropout
>>> model(x)
Array([[-0.9353421,  0.       ,  1.434417 ,  0.       ]], dtype=float32)

>>> model.eval() # don't use dropout
>>> model(x)
Array([[-0.46767104, -0.7213411 ,  0.7172085 , -0.31562346]], dtype=float32)
rate#

dropout 的概率。(_不是_保留率!)

类型

float

broadcast_dims#

将共享相同 dropout 掩码的维度

类型

collections.abc.Sequence[int]

deterministic#

如果为 false,则输入按 1 / (1 - rate) 缩放并进行掩码,如果为 true,则不应用掩码,并按原样返回输入。

类型

bool

rng_collection#

请求 rng 密钥时使用的 rng 集合名称。

类型

str

rngs#

rng 密钥。

类型

flax.nnx.rnglib.Rngs | None