Flax NNX 术语表

Flax NNX 术语表#

有关其他术语,请参阅JAX 术语表

过滤器#

一种从 Flax NNX 模块nnx.Module)中仅提取某些 nnx.Variable 对象的方法。这通常通过在 nnx.Module 上调用 nnx.split 来完成。请参阅过滤器指南以了解更多信息。

折叠到#

在 Flax 中,折叠到意味着给定一个输入 PRNG 密钥和整数,生成一个新的 JAX 伪随机数生成器 (PRNG) 密钥。当您想要生成新密钥但仍然能够使用原始 PRNG 密钥时,通常会使用此方法。您也可以在 JAX 中使用 jax.random.split 来实现此目的,但此方法实际上会创建两个 PRNG 密钥,这会比较慢。在随机性/PRNG 指南中了解 Flax 如何自动生成新的 PRNG 密钥。

GraphDef#

nnx.GraphDef 是一个类,表示 Flax 模块nnx.Module)的所有静态、无状态和 Python 式部分。

合并#

请参阅拆分和合并

模块#

nnx.Module 是一个数据类,能够以引用透明的形式定义和初始化参数。它负责存储和更新自身内的 :term:`Variable` 对象和参数。

参数 / parameters#

nnx.Paramnnx.Variable 的一个特定子类,通常包含可训练的权重。

PRNG 状态#

Flax nnx.Module 可以保留一个 伪随机数生成器 (PRNG) 状态对象 nnx.Rngs 的引用,该对象可以生成新的 JAX PRNG 密钥。这些密钥用于通过 JAX 的函数式 PRNG 生成随机 JAX 数组。您可以使用具有不同种子的 PRNG 状态来为您的模型添加更精细的控制(例如,为参数和 dropout 掩码使用独立的随机数)。有关更多详细信息,请参阅 Flax 随机性/PRNG 指南

拆分和合并#

nnx.split 是一种用两部分表示 nnx.Module 的方法:1) 一个静态的 Flax NNX GraphDef,它捕获其 Python 式静态信息;以及 2) 一个或多个 变量状态,它们以 JAX pytrees 的形式捕获其 JAX 数组jax.Array)。它们可以使用 nnx.merge 合并回原始 nnx.Module

变换#

Flax NNX 变换 (transform) 是 JAX 变换的包装版本,它允许被变换的函数将 Flax NNX 模块nnx.Module)作为输入或输出。例如,jax.jit 的“提升”版本是 nnx.jit。查看Flax NNX 变换指南以了解更多信息。

变量#

Flax 模块 中存在的权重/参数/数据/数组 nnx.Variable。变量在模块内部定义为 nnx.Variable 或其子类。

变量状态#

nnx.VariableState 是一个纯粹的函数式 JAX pytree,包含 模块 内的所有 变量。由于它是纯粹的,因此它可以作为 JAX 变换函数的输入或输出。nnx.VariableState 通过在 nnx.Module 上使用 nnx.split 获得。(请参阅拆分模块以了解更多信息。)