提升变换#

⚠️ 高级主题 ⚠️

本设计说明解释了 flax.linen.transform 的底层实现,它允许在 Flax Module 中使用 JAX 变换。

介绍#

JAX 使用函数式 API,这意味着它仅在使用没有副作用的函数时才能保证正确行为(JAX 文档)。通常,这些副作用是由于对函数外部的对象进行突变造成的。

函数式范式具有一些优势,例如能够显式地推理状态和随机性。函数输出仅在输入参数发生变化时才会发生变化。因此,函数保证具有确定性行为。

但纯函数为 JAX 提供了另一个重大优势:具体来说,它们允许使用函数式变换。例如 jax.vmap(f) 将向量化函数 f。因为 f 不能有副作用,所以 f 的向量化/并行版本是定义明确的。要了解为什么我们需要此限制,请考虑如果 f 会递增计数器或绘制随机数会发生什么。f 会为向量中的每个项目绘制相同或不同的随机数吗?批次中的每个项目都会有自己的计数器,还是计数器在项目之间共享?如果 f 是并行计算的,计数器将按什么顺序递增?所有这些问题的答案都是“视情况而定”。行为是模棱两可的,函数式约束优雅地避免了这个问题。

Flax 引入了一种在与 JAX 兼容的形式下使用有限随机性和有状态变量的安全方法。Flax 中状态之所以没有问题,是因为它是局部的:在 Flax Module 内部,存在变量和 PRNG 序列,但在外部,只有 JAX 数组和 PRNG 密钥。

对于大多数用例,Flax 用于以有状态的方式定义模型。因为 Module 在外部的行为就像纯函数一样,我们可以充分利用 JAX 及其所有变换。但是,有些情况下我们希望将变换和 Module 结合起来,以实现两全其美。本设计说明解释了我们如何扩展 JAX 的函数式变换,以使其能够作用于具有内部状态和随机性的 Module

函数化#

在我们深入探讨细节之前,让我们考虑一个简单的示例,在该示例中,我们希望在 Module 中使用 vmap

首先,我们定义一个没有变换的简单 MLP

import jax
from jax import random, numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    h = nn.Dense(4, name='hidden')(xs)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

现在如果我们希望为 xs 中的每个项目都有单独的 MLP 参数怎么办?如果这是“普通 JAX”,我们可以想象编写类似于 jax.vmap(apply_mlp)(mlp_params, xs) 的代码。但是,在 Linen 中执行此操作实际上会导致失败

class NaiveVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP()
    return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs)  # fails

vmap 作用于 mlp 时,JAX 会引发错误,因为它不是 JAX 数组,也不是数组的简单容器。我们不能责怪 JAX 拒绝执行此未定义的任务。毕竟,目前还不清楚这里应该发生什么。MLP 内部的参数甚至尚未初始化,我们需要为每组参数提供一个单独的 PRNG 密钥。jax.vmap 只能沿某个轴进行广播或映射,但不能自动拆分 PRNG 密钥。因此,我们必须手动调用 jax.random.split

我们可以先将 MLP 转换为纯 init 和 apply 函数来解决这个问题。然后,我们使用 param 方法来存储参数

class ManualVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP(parent=None)
    init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
    apply_fn = jax.vmap(mlp.apply, in_axes=0)
    mlp_params = self.param('mlp', init_fn, xs)
    return apply_fn({'params': mlp_params}, xs)

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        hidden: {
            bias: (3, 4),
            kernel: (3, 4, 4),
        },
        out: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

这里,MLP(parent=None) 创建了一个独立的 MLP 实例。这避免了为当前模块内部的子模块保留一个名称。虽然这不是严格必要的,但这也有助于确保我们不会意外地以有状态方式使用 MLP 实例,而被迫通过 .init.apply 来使用它。

此示例仍然比较简洁,但它已经需要一些额外的“簿记”语句才能使其正常工作。但是,此实现具有以下几个局限性

  1. 在初始化期间,我们通过 init_fnapply_fn 两次调用子模块。如果子模块使用相同的技巧执行函数式变换,随着模块调用的数量像 2^d 一样增长,我们最终将执行大量代码,其中 d 是嵌套函数变换的数量。

  2. 此实现假设子模块仅需要参数 RNG 序列。

  3. 此实现假设我们只在“params”集合中创建变量,并在 init 期间创建变量。但是,它不支持其他变量集合,也不支持在 apply 期间创建/更新变量。

第 3 点尤其使得手动函数化变得繁琐。请随时尝试在 MLP 模块中扩展上述示例,其中包含一个 nn.BatchNorm 层。这将需要处理一些额外的复杂性,例如存储更新的批次统计信息,并确保批次统计信息在 vmap 内不可变时(例如:评估模式)仍然不可变。

我们将将有状态的 Module 转换为纯函数的过程称为“函数化”。通过将有状态的 Module 暂时转换为函数,使其与 JAX 的函数式变换兼容。

提升#

Flax 为手动函数化提供了一种替代方法,我们称之为提升变换。提升变换在 flax.core.lift 中定义。所有提升的 JAX 变换都是通过一个名为 pack 的通用提升 API 定义的。

为了定义 pack,必须做出一些决定。 pack 的实现控制着变量和 rng 的提升方式,以及用户控制的粒度。它还必须决定是在变量定义还是变换定义时做出提升决定。

提升粒度#

借助 Linen API,用户可以定义任意变量集合和 PRNG 序列。每个集合中的变量都以相同的方式提升。

集合通常被赋予有语义意义的名称,例如“params”或“batch_stats”,而不是通用名称,例如“state”。因为集合具有语义意义,所以我们可以在变换级别决定每个集合应如何提升。例如,当我们将批次维度添加到模型时,我们希望共享所有参数变量。

同时,我们可以编写通用代码,使用变换而不必知道子模块将创建哪些类型的变量。因此,集合在精细粒度控制和通用性之间取得了平衡。我们还避免了循环遍历所有变量并尝试基于命名约定以特殊方式拆分集合的脆弱字符串匹配代码,例如:将所有名称前缀为“kernel”的变量作为目标。如果需要更精细的控制,用户只需将一组变量拆分为多个集合,这些集合应以不同的方式处理。

变换与变量控制#

提升行为可以在转换级别或变量定义期间定义。我们使用转换级别的提升行为定义。这样做的原因是,存在许多具有不同行为的不同转换。例如:vmap 具有广播和矢量化的参数,而 scan 具有扫描、携带和广播参数。一个变量需要为所有这些转换定义它的行为,否则一个 Module 将不兼容这些转换。或者,我们需要为如何处理转换做出默认决定。但是,这可能导致静默错误,因为给定用户意图,该行为可能实际上并不有效。

lift 包还提供了一个通用目的的 transform,它允许任意函数转换变量集合。例如,这可用于通过转置权重来绑定捆绑自动编码器中的权重。如果在变量定义时做出提升决定,则不清楚是否可以定义类似的通用目的转换。

Linen#

提升模块不知道 Linen 的 Module API。而是直接对 flax.core.Scope 实例进行操作。一个 Scope 实例包含 Module 的变量和 PRNG 序列。每个 Module 实例在 .scope 字段中都有一个 Scope 实例,如果它有一个父级或它是使用 initapply 创建的。通常,顶级 Module 实例(您在其中调用 initapply)是唯一没有绑定到它的 ScopeModule 实例。

当一个 Module 被转换时,我们使用 flax.core.lift API 来提升范围并使用 Module.clone() 来创建一个新的 Module 实例,并将提升的范围绑定到它。

flax.linen.transforms 公开对 flax.core.lift 中的转换的包装器。核心提升 API 对函数进行操作,而 Linen 包装器可以转换 Module 类或 Module 方法。

因此,提升独立于 Linen API 实现。这种关注点分离简化了实现,同时可能允许替代的 Module 抽象建立在提升和状态管理的通用核心之上。

实现#

pack(fn, in_vars, out_vars, rngs) API 经历以下阶段

  1. 范围去重

    此阶段仅与多个范围一起提升时相关。在这种情况下,我们必须首先找到根范围的集合。如果范围的任何祖先都不在需要提升的范围集合中,则该范围是根范围。

    通过仅提升根范围,我们避免了对相同变量进行两次提升。

    对于非根范围,我们存储对其祖先范围的引用以及一条路径,以便我们可以稍后重建它(阶段 4)。

  2. 筛选阶段

    变量和 PRNG 序列被分成组。这样,fn 可以分别将每一组提升到转换中。一组由指定为以下内容的过滤器定义:

    • 集合/prng 名称的列表

    • True(匹配所有内容)

    • False(不匹配任何内容)

    • DenyList(filter)(匹配除指定集合之外的所有内容(例如:DenyList(['params']) 匹配除“params”集合之外的所有内容)。)。

    一个集合或 PRNG 序列只能放入一组。如果一个集合匹配多个过滤器,它将被放入第一个匹配过滤器的组中。如果一个集合或 PRNG 序列不匹配任何过滤器,它将不会被提升。这意味着它不能在转换内部使用,尝试这样做会导致引发错误。例如,in_vars = (["params"], True) 将导致“params”集合被放入第一组,所有其他集合被放入第二组。

    对于匹配的每个 PRNG 序列,我们通过调用 make_rng 来播种一个新的 PRNG 序列。这避免了在提升的转换完成后更新 PRNG 状态的需要。

  3. 特定于转换的提升

    fn 被变量和 PRNG 组调用。JAX 转换具有不同的签名和提升选项。可以说最干净的示例是 vmap。在 vmap 的情况下,函数参数、PRNG 和变量集合被传递到一个 jax.vmap 包装函数中。

  4. 范围重建

    现在变量和 PRNG 已经在转换内部提升,我们希望重新创建提升的范围。Pack 使用一个 scope_fn 调用 fn,该函数接受提升的变量和 PRNG,并返回具有提升的变量和 rng 序列的重建范围。

  5. 重新打包阶段

    在使用提升的范围之后,我们必须检索更新的变量(PRNG 序列可以简单地丢弃)。pack 传递 repack_fn 来支持这一点。此阶段类似于阶段 2,只是我们只提升变量,忽略不可变变量。不可变变量不能更新。因此,它们不应从转换函数中返回。

  6. 提交阶段

    pack 预期 fn 返回一对,其中第一项将从 pack 中返回,第二项应为重新打包的变量。更新的变量存储在原始/未提升的范围中,以便在转换完成后,在转换内部发生的变异得以保留。

使用 pack 示例#

使用 pack 来转置变量集合中的每个矩阵的最小示例

from flax.core import lift
from flax.core import Scope, init, apply, nn as core_nn

def lift_transpose(fn, target='params', variables=True, rngs=True):
  # by default we transpose 'params' and simply pass through all other variables.
  def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args):
    # normally we would first call into a JAX transformed function here...
    target, rest = variable_groups
    def trans(x):
      if x.ndim == 2:
        return x.T
      return x
    target = jax.tree_util.tree_map(trans, target)
    variable_groups = (target, rest)
    scope = scope_fn(variable_groups, rng_groups)
    y = fn(scope, *args)
    out_variables = repack_fn(scope)
    return y, out_variables
  return lift.pack(
      wrapper,
      in_variable_filters=(target, variables),
      out_variable_filters=(variables,),
      rng_filters=(rngs,))

x = jnp.ones((3, 2))
y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)

注意,大多数用户不需要直接与 pack 交互。当您发现现有的提升转换尚不支持的用例时,请打开 GitHub 问题。

支持的转换#

Jax 转换

Linen 中是否支持?

评论

vmap

scan

携带变量不能在扫描主体内部初始化。

remat

jit

当前实现可能会导致不必要的重新编译。

jvp

vjp

custom_vjp

custom_jvp

while_loop

携带变量不能在 while_loop 主体内部初始化。

cond

变量初始化/变异必须在分支之间结构上匹配。

switch

变量初始化/变异必须在分支之间结构上匹配。

pmap

xmap

参考资料

Linen 示例#

回到我们最初的示例,我们现在可以使用 nn.vmap 来简化我们的实现

class LinenVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs)

variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        Dense_0: {
            bias: (3, 4),
            kernel: (3, 2, 4),
        },
        Dense_1: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

这里我们使用 variable_axes={'params': 0} 来指示参数是矢量化的而不是共享的,而 split_rngs={'params': True} 意味着每个参数集都是独立初始化的。

我们还可以通过添加一个 BatchNorm 层来扩展示例以包含一些内部状态

class StatefulMLP(nn.Module):
  @nn.compact
  def __call__(self, x, *, train):
    h = nn.Dense(4, name='hidden')(x)
    h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

class LinenStatefulVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs, *, train):
    VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs, train=train)
variables = LinenStatefulVmapMLP().init(random.key(0), xs)

我们只需要在 nn.vmap 中添加 'batch_stats': 0,表示批处理统计信息是矢量化的,而不是沿着第一个轴共享的。

替代方案#

其他数值计算框架将变量视为一等公民。功能化的另一种选择是使用一个变量系统,该系统要么集成到 JAX 中,要么位于 JAX 之上。这样做的优点是,每个变量的提升变得更容易。如果变量是 JAX IR(JAXPR)的一部分,我们可以检查哪些变量必须在某个计算中提升。或者,可以使用集合标签对它们进行注释,以决定各种提升选项。

这种方法的缺点是变量系统更复杂。变量是相关的引用,并且打破了函数式编程的核心假设(参见 引用透明性)目前具有函数式接口的其他 API 可能也需要集成(例如:检查点和优化 API)。