Flax NNX 术语表#
有关其他术语,请参阅JAX 术语表。
- 过滤器#
一种从 Flax NNX 模块(
nnx.Module
)中仅提取某些 nnx.Variable 对象的方法。这通常通过在nnx.Module
上调用nnx.split
来完成。请参阅过滤器指南以了解更多信息。- 折叠到#
在 Flax 中,折叠到意味着给定一个输入 PRNG 密钥和整数,生成一个新的 JAX 伪随机数生成器 (PRNG) 密钥。当您想要生成新密钥但仍然能够使用原始 PRNG 密钥时,通常会使用此方法。您也可以在 JAX 中使用 jax.random.split 来实现此目的,但此方法实际上会创建两个 PRNG 密钥,这会比较慢。在随机性/PRNG 指南中了解 Flax 如何自动生成新的 PRNG 密钥。
- GraphDef#
nnx.GraphDef
是一个类,表示 Flax 模块(nnx.Module
)的所有静态、无状态和 Python 式部分。- 合并#
请参阅拆分和合并。
- 模块#
nnx.Module
是一个数据类,能够以引用透明的形式定义和初始化参数。它负责存储和更新自身内的 :term:`Variable` 对象和参数。- 参数 / parameters#
nnx.Param
是nnx.Variable
的一个特定子类,通常包含可训练的权重。- PRNG 状态#
Flax
nnx.Module
可以保留一个 伪随机数生成器 (PRNG) 状态对象nnx.Rngs
的引用,该对象可以生成新的 JAX PRNG 密钥。这些密钥用于通过 JAX 的函数式 PRNG 生成随机 JAX 数组。您可以使用具有不同种子的 PRNG 状态来为您的模型添加更精细的控制(例如,为参数和 dropout 掩码使用独立的随机数)。有关更多详细信息,请参阅 Flax 随机性/PRNG 指南。- 拆分和合并#
nnx.split
是一种用两部分表示nnx.Module
的方法:1) 一个静态的 Flax NNX GraphDef,它捕获其 Python 式静态信息;以及 2) 一个或多个 变量状态,它们以 JAX pytrees 的形式捕获其 JAX 数组(jax.Array
)。它们可以使用nnx.merge
合并回原始nnx.Module
。- 变换#
Flax NNX 变换 (transform) 是 JAX 变换的包装版本,它允许被变换的函数将 Flax NNX 模块(
nnx.Module
)作为输入或输出。例如,jax.jit 的“提升”版本是nnx.jit
。查看Flax NNX 变换指南以了解更多信息。- 变量#
Flax 模块 中存在的权重/参数/数据/数组
nnx.Variable
。变量在模块内部定义为nnx.Variable
或其子类。- 变量状态#
nnx.VariableState
是一个纯粹的函数式 JAX pytree,包含 模块 内的所有 变量。由于它是纯粹的,因此它可以作为 JAX 变换函数的输入或输出。nnx.VariableState
通过在nnx.Module
上使用nnx.split
获得。(请参阅拆分和模块以了解更多信息。)