Flax 哲学#

没有特定的顺序

  • 库代码应该易于阅读和理解。

  • 优先复制代码而不是糟糕的抽象。

  • 通常,优先复制代码而不是向函数添加选项。

  • 注释驱动的设计:如果难以记录代码,请考虑更改设计。

  • 单元测试驱动的设计:如果难以测试代码,请考虑更改设计。

  • 人们通过复制现有实现来启动项目——使基本实现出色。

  • 如果我们向开发人员公开抽象,我们就拥有心理负担。

  • 面向开发人员的函数式编程抽象会让一些用户感到困惑,在优势很高的地方公开它们。

  • “阅读手册”不是对开发人员困惑的适当回应。框架应该引导开发人员走向好的解决方案,例如通过断言和错误消息。

  • 无用的错误消息是一个错误。

  • “调试比最初编写代码难两倍。因此,如果你以尽可能巧妙的方式编写代码,那么从定义上来说,你还不够聪明,无法调试它。”——布莱恩·克尼汉

设计原则#

Flax 是建立在 JAX 之上的神经网络库,已被越来越多的用户采用,最值得注意的是在 MLPerf 0.7 基准测试的 JAX 提交中。我们在过去一年(以及与用户和 JAX 核心开发人员的许多对话)中的经验指导了称为 Linen (flax.linen) 的 API 的重新设计,以回应以下基本设计问题。

神经网络库如何从建立在 JAX 之上并利用 JAX 的独特优势中受益?#

世界已经有了 TensorFlow 和 PyTorch,没有必要构建它们的克隆。我们认为,JAX 采用的可组合函数变换方法为使神经网络代码比现有库更易于维护、更具可扩展性和更高性能开辟了新领域。虽然我们努力提供一个对那些熟悉 Keras/Sonnet/PyTorch 的人来说熟悉的 API,但 Linen 本质上是一个用于在 JAX 中定义神经网络的函数式系统。仅举几个例子说明我们认为一个以 JAX 为目标的库可以实现什么

  • 将模型编写为“单样本”代码,并使用 jax.vmap 自动引入批处理。

  • 在 NLP 和其他掩码问题中自动处理参差不齐的批次。

  • 通过利用重新计算的 scan 为大型卷积网络创建高效的编译时和运行时模型。

  • 通过启用轻松的重新计算、可逆性和模型并行数据分片来消除内存问题。

如何与 JAX 变换交互?#

可以说,神经网络库的全部意义在于提供一个隐式变量管理 API,以节省用户手动将数千个变量通过复杂的函数树传递的麻烦。但是,JAX 在纯函数上运行。为了处理当前和将来的 JAX 变换(以任何方式配置和组合),Linen 模块直接“函数化”,也就是说,自动就地转换为形式为以下形式的显式函数

\[f \left( v_{in}, x \right) \rightarrow v_{out}, y\]

其中 \(v_{in}\) 是模型使用的变量集合和 PRNG 状态,\(v_{out}\) 是经过修改的输出变量集合,\(x\) 是输入数据,\(y\) 是输出数据。应用 JAX 变换然后简化为为各种变量集合和 PRNG 状态指定任何特定于参数的变换选项。这释放了 JAX 变换 的灵活性和强大功能——例如,可以通过使用 jax.pmap 以不同的方式实现设备并行训练或每个设备集成,而无需任何明确的库支持。此外,**在 模块 内**,我们公开了一些围绕复杂 JAX 变换(例如 jax.vmapjax.lax.scan)的轻量级包装器,这些包装器注释了 JAX 将如何变换每个变量集合。重要的是,我们正确处理在映射和循环变换下创建新变量和变换变量的非平凡情况,以进行初始化和应用。

参数如何表示?我们如何处理更新有状态变量的通用“可微算法”?#

我们遵循 JAX 函数式约定,将数据存储在“pytree”中:嵌套在嵌套元组、列表、字典中的 JAX 数组。因为研究人员不可避免地会手动与这些数据交互,所以我们使用具有有意义的默认键的嵌套字典,并提供几个实用程序(遍历等)来直接处理它们。Linen 使用 Python 冻结字典的加速版本,该版本缓存其 JAX 扁平化形式,以加快 jited 函数调用的开销。

Flax 通过允许模型接受多个不同“类型”的集合来概括神经网络的操作:参数、批次归一化统计、自回归缓存、调试信息、细粒度超参数等。每个集合都存储在一个与模型具有相同结构的嵌套字典中。重要的是,我们将这些不同的类型混淆在“状态”这个单一的模糊名词之下,而是将不同逻辑类型的变量分开,这些变量在 JAX 变换和变异(例如,训练与预测)下可以不同地对待。同样,我们允许在 模块 内存在多个独立的命名 PRNG 链中,以便分别处理随机性,用于不同的应用程序,例如初始化、丢弃、采样等。

在每个阶段,与神经网络相关联的数据不会保留在自定义对象层次结构中,而是保留在显式、Python 和 JAX 原生形式中,这很容易进行内省和修改。用户利用它将 TF 和 PyTorch 检查点映射到 Flax,实现特定于子模型的损失项,并执行快速模型手术等。为了保存这些数据,大多数 Flax 示例通过高效的“msgpack”二进制格式存储这些嵌套字典——但是,由于变量只是 Python 字典,因此你可以直接使用任何(不了解 JAX 的)序列化库。

如何与纯函数式 JAX 代码交互?#

为了广泛地对 JAX 生态系统有用,用户不应该需要大量重构他们的代码,以便为给定的数值任务添加“可训练性”。“库不应该妨碍。”在 Linen 中使用纯函数式代码很简单:模块 实现只是具有命名变量的 JAX 代码。在其他纯函数式代码中使用 Linen 模块可以像使用单个顶层模块转换一样简单,以允许初始化和纯应用可能包含各种可训练部分的任何 JAX 程序。