术语表#
有关其他术语,请参阅Jax 术语表。
- 绑定模块#
当
Module
通过常规 Python 对象构造创建时(例如 module = SomeModule(args…),它处于未绑定状态。这意味着仅设置数据类属性,并且没有变量绑定到模块。当纯函数Module.init()
或Module.apply()
被调用时,Flax 克隆模块并将变量绑定到它,并且模块的方法代码在本地绑定状态下执行,允许执行诸如直接调用子模块而无需提供变量的操作。有关更多详细信息,请参阅模块生命周期。- 紧凑/非紧凑模块#
具有单个方法的模块能够通过使用
@nn.compact
装饰器内联声明子模块和变量。这些被称为“紧凑风格模块”,而定义setup()
方法的模块(通常但并非总是具有多个可调用方法)被称为“设置风格模块”。要了解更多信息,请参阅setup 与 compact 指南。- 折叠#
给定输入 PRNG 密钥和整数生成新的 PRNG 密钥。通常在您想要生成新密钥但随后仍能够使用原始 rng 密钥时使用。您也可以使用jax.random.split 执行此操作,但这实际上会创建两个 RNG 密钥,这会更慢。请参阅我们的RNG 指南,了解 Flax 如何在
Modules
中自动生成新的 PRNG 密钥。- FrozenDict#
一个不可变字典,可以“解冻”为常规的可变字典。在内部,Flax 使用 FrozenDict 来确保变量字典不会意外地发生变异。注意:我们正在考虑从我们的 API 中返回到常规字典,并且仅在内部使用 FrozenDict。(请参阅#1223)。
- 函数核心#
flax 核心库为通过模型线程变量和 PRNG 实现简单的容器 Scope API,以及转换传递 Scope 对象的函数所需的提升机制。基于 python 类别的模块 API 建立在此核心库之上。
- 延迟初始化#
Flax 中的变量是延迟初始化的,仅在需要时才初始化。也就是说,在模块的正常执行期间,如果在提供的变量集合数据中找不到请求的变量名,我们会调用初始化器函数来创建它。这使我们能够在相同的代码路径下处理初始化和应用程序,从而简化了使用 JAX 转换与层。
- 提升转换#
请参阅Flax 文档。
- 模块#
一个数据类,允许以引用透明的形式定义和初始化参数。它负责存储和更新其自身内部的变量和参数。模块可以轻松地转换为函数,从而使它们可以与 JAX 转换(如 vmap 和 scan)轻松使用。
- 参数/参数#
“params”是变量字典 (dict) 中规范的变量集合。“params”集合通常包含可训练的权重。
- RNG 序列#
在 Flax
Modules
内部,您可以通过Module.make_rng()
获取新的PRNG 密钥。这些密钥可用于通过JAX 的函数式随机数生成器 生成随机数。拥有不同的 RNG 序列(例如,“params”和“dropout”)允许在多主机设置中进行细粒度控制(例如,在不同主机上以相同的方式初始化参数,但具有不同的丢弃掩码)并在提升转换 时对这些序列进行不同的处理。有关更多详细信息,请参阅RNG 指南。- 范围#
一个用于保存每层变量和 PRNG 密钥的容器类。
- 形状推断#
模块不需要在其定义中指定输入数组的形状。Flax 在初始化时会检查输入数组,并推断模型中参数的正确形状。
- TrainState#
- 变量#
位于变量集合 叶节点中的权重/参数/数据/数组。变量是在模块内部使用
Module.variable()
定义的。集合“params”的变量简称为参数,可以使用Module.param()
设置。- 变量集合#
变量字典中的条目,包含模型使用的权重/参数/数据/数组。“params”是变量字典中的规范集合。它们通常是可微的,由外部类似 SGD 的循环/优化器更新,而不是由前向传递代码直接修改。
- 变量字典#
包含变量集合 的字典。每个变量集合都是一个从字符串名称(例如,“params”或“batch_stats”)到具有变量 作为叶节点的(可能嵌套的)字典的映射,匹配子模块树结构。在Jax 文档 中阅读有关 pytree 和叶节点的更多信息。