术语表

术语表#

有关其他术语,请参阅Jax 术语表

绑定模块#

Module 通过常规 Python 对象构造创建时(例如 module = SomeModule(args…),它处于未绑定状态。这意味着仅设置数据类属性,并且没有变量绑定到模块。当纯函数Module.init()Module.apply() 被调用时,Flax 克隆模块并将变量绑定到它,并且模块的方法代码在本地绑定状态下执行,允许执行诸如直接调用子模块而无需提供变量的操作。有关更多详细信息,请参阅模块生命周期

紧凑/非紧凑模块#

具有单个方法的模块能够通过使用@nn.compact 装饰器内联声明子模块和变量。这些被称为“紧凑风格模块”,而定义setup() 方法的模块(通常但并非总是具有多个可调用方法)被称为“设置风格模块”。要了解更多信息,请参阅setup 与 compact 指南

折叠#

给定输入 PRNG 密钥和整数生成新的 PRNG 密钥。通常在您想要生成新密钥但随后仍能够使用原始 rng 密钥时使用。您也可以使用jax.random.split 执行此操作,但这实际上会创建两个 RNG 密钥,这会更慢。请参阅我们的RNG 指南,了解 Flax 如何在Modules 中自动生成新的 PRNG 密钥。

FrozenDict#

一个不可变字典,可以“解冻”为常规的可变字典。在内部,Flax 使用 FrozenDict 来确保变量字典不会意外地发生变异。注意:我们正在考虑从我们的 API 中返回到常规字典,并且仅在内部使用 FrozenDict。(请参阅#1223)。

函数核心#

flax 核心库为通过模型线程变量和 PRNG 实现简单的容器 Scope API,以及转换传递 Scope 对象的函数所需的提升机制。基于 python 类别的模块 API 建立在此核心库之上。

延迟初始化#

Flax 中的变量是延迟初始化的,仅在需要时才初始化。也就是说,在模块的正常执行期间,如果在提供的变量集合数据中找不到请求的变量名,我们会调用初始化器函数来创建它。这使我们能够在相同的代码路径下处理初始化和应用程序,从而简化了使用 JAX 转换与层。

提升转换#

请参阅Flax 文档

模块#

一个数据类,允许以引用透明的形式定义和初始化参数。它负责存储和更新其自身内部的变量和参数。模块可以轻松地转换为函数,从而使它们可以与 JAX 转换(如 vmapscan)轻松使用。

参数/参数#

“params”是变量字典 (dict) 中规范的变量集合。“params”集合通常包含可训练的权重。

RNG 序列#

在 Flax Modules 内部,您可以通过Module.make_rng() 获取新的PRNG 密钥。这些密钥可用于通过JAX 的函数式随机数生成器 生成随机数。拥有不同的 RNG 序列(例如,“params”和“dropout”)允许在多主机设置中进行细粒度控制(例如,在不同主机上以相同的方式初始化参数,但具有不同的丢弃掩码)并在提升转换 时对这些序列进行不同的处理。有关更多详细信息,请参阅RNG 指南

范围#

一个用于保存每层变量和 PRNG 密钥的容器类。

形状推断#

模块不需要在其定义中指定输入数组的形状。Flax 在初始化时会检查输入数组,并推断模型中参数的正确形状。

TrainState#

请参阅flax.training.train_state.TrainState

变量#

位于变量集合 叶节点中的权重/参数/数据/数组。变量是在模块内部使用Module.variable() 定义的。集合“params”的变量简称为参数,可以使用Module.param() 设置。

变量集合#

变量字典中的条目,包含模型使用的权重/参数/数据/数组。“params”是变量字典中的规范集合。它们通常是可微的,由外部类似 SGD 的循环/优化器更新,而不是由前向传递代码直接修改。

变量字典#

包含变量集合 的字典。每个变量集合都是一个从字符串名称(例如,“params”或“batch_stats”)到具有变量 作为叶节点的(可能嵌套的)字典的映射,匹配子模块树结构。在Jax 文档 中阅读有关 pytree 和叶节点的更多信息。