提升变换#
⚠️ 高级主题 ⚠️
本设计说明解释了 flax.linen.transform
的底层实现,它允许在 Flax Module
中使用 JAX 变换。
介绍#
JAX 使用函数式 API,这意味着它仅在使用没有副作用的函数时才能保证正确行为(JAX 文档)。通常,这些副作用是由于对函数外部的对象进行突变造成的。
函数式范式具有一些优势,例如能够显式地推理状态和随机性。函数输出仅在输入参数发生变化时才会发生变化。因此,函数保证具有确定性行为。
但纯函数为 JAX 提供了另一个重大优势:具体来说,它们允许使用函数式变换。例如 jax.vmap(f)
将向量化函数 f
。因为 f
不能有副作用,所以 f
的向量化/并行版本是定义明确的。要了解为什么我们需要此限制,请考虑如果 f
会递增计数器或绘制随机数会发生什么。f
会为向量中的每个项目绘制相同或不同的随机数吗?批次中的每个项目都会有自己的计数器,还是计数器在项目之间共享?如果 f
是并行计算的,计数器将按什么顺序递增?所有这些问题的答案都是“视情况而定”。行为是模棱两可的,函数式约束优雅地避免了这个问题。
Flax 引入了一种在与 JAX 兼容的形式下使用有限随机性和有状态变量的安全方法。Flax 中状态之所以没有问题,是因为它是局部的:在 Flax Module
内部,存在变量和 PRNG 序列,但在外部,只有 JAX 数组和 PRNG 密钥。
对于大多数用例,Flax 用于以有状态的方式定义模型。因为 Module
在外部的行为就像纯函数一样,我们可以充分利用 JAX 及其所有变换。但是,有些情况下我们希望将变换和 Module
结合起来,以实现两全其美。本设计说明解释了我们如何扩展 JAX 的函数式变换,以使其能够作用于具有内部状态和随机性的 Module
。
函数化#
在我们深入探讨细节之前,让我们考虑一个简单的示例,在该示例中,我们希望在 Module
中使用 vmap
。
首先,我们定义一个没有变换的简单 MLP
import jax
from jax import random, numpy as jnp
from flax import linen as nn
class MLP(nn.Module):
@nn.compact
def __call__(self, xs):
h = nn.Dense(4, name='hidden')(xs)
h = nn.relu(h)
return nn.Dense(1, name='out')(h)
现在如果我们希望为 xs
中的每个项目都有单独的 MLP 参数怎么办?如果这是“普通 JAX”,我们可以想象编写类似于 jax.vmap(apply_mlp)(mlp_params, xs)
的代码。但是,在 Linen 中执行此操作实际上会导致失败
class NaiveVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs):
mlp = MLP()
return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs) # fails
当 vmap
作用于 mlp
时,JAX 会引发错误,因为它不是 JAX 数组,也不是数组的简单容器。我们不能责怪 JAX 拒绝执行此未定义的任务。毕竟,目前还不清楚这里应该发生什么。MLP 内部的参数甚至尚未初始化,我们需要为每组参数提供一个单独的 PRNG 密钥。jax.vmap
只能沿某个轴进行广播或映射,但不能自动拆分 PRNG 密钥。因此,我们必须手动调用 jax.random.split
。
我们可以先将 MLP
转换为纯 init 和 apply 函数来解决这个问题。然后,我们使用 param
方法来存储参数
class ManualVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs):
mlp = MLP(parent=None)
init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
apply_fn = jax.vmap(mlp.apply, in_axes=0)
mlp_params = self.param('mlp', init_fn, xs)
return apply_fn({'params': mlp_params}, xs)
xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
mlp: {
hidden: {
bias: (3, 4),
kernel: (3, 4, 4),
},
out: {
bias: (3, 1),
kernel: (3, 4, 1),
},
},
}
"""
这里,MLP(parent=None)
创建了一个独立的 MLP
实例。这避免了为当前模块内部的子模块保留一个名称。虽然这不是严格必要的,但这也有助于确保我们不会意外地以有状态方式使用 MLP 实例,而被迫通过 .init
或 .apply
来使用它。
此示例仍然比较简洁,但它已经需要一些额外的“簿记”语句才能使其正常工作。但是,此实现具有以下几个局限性
在初始化期间,我们通过
init_fn
和apply_fn
两次调用子模块。如果子模块使用相同的技巧执行函数式变换,随着模块调用的数量像 2^d 一样增长,我们最终将执行大量代码,其中 d 是嵌套函数变换的数量。此实现假设子模块仅需要参数 RNG 序列。
此实现假设我们只在“params”集合中创建变量,并在
init
期间创建变量。但是,它不支持其他变量集合,也不支持在apply
期间创建/更新变量。
第 3 点尤其使得手动函数化变得繁琐。请随时尝试在 MLP
模块中扩展上述示例,其中包含一个 nn.BatchNorm
层。这将需要处理一些额外的复杂性,例如存储更新的批次统计信息,并确保批次统计信息在 vmap
内不可变时(例如:评估模式)仍然不可变。
我们将将有状态的 Module 转换为纯函数的过程称为“函数化”。通过将有状态的 Module
暂时转换为函数,使其与 JAX 的函数式变换兼容。
提升#
Flax 为手动函数化提供了一种替代方法,我们称之为提升变换。提升变换在 flax.core.lift
中定义。所有提升的 JAX 变换都是通过一个名为 pack
的通用提升 API 定义的。
为了定义 pack
,必须做出一些决定。 pack
的实现控制着变量和 rng 的提升方式,以及用户控制的粒度。它还必须决定是在变量定义还是变换定义时做出提升决定。
提升粒度#
借助 Linen API,用户可以定义任意变量集合和 PRNG 序列。每个集合中的变量都以相同的方式提升。
集合通常被赋予有语义意义的名称,例如“params”或“batch_stats”,而不是通用名称,例如“state”。因为集合具有语义意义,所以我们可以在变换级别决定每个集合应如何提升。例如,当我们将批次维度添加到模型时,我们希望共享所有参数变量。
同时,我们可以编写通用代码,使用变换而不必知道子模块将创建哪些类型的变量。因此,集合在精细粒度控制和通用性之间取得了平衡。我们还避免了循环遍历所有变量并尝试基于命名约定以特殊方式拆分集合的脆弱字符串匹配代码,例如:将所有名称前缀为“kernel”的变量作为目标。如果需要更精细的控制,用户只需将一组变量拆分为多个集合,这些集合应以不同的方式处理。
变换与变量控制#
提升行为可以在转换级别或变量定义期间定义。我们使用转换级别的提升行为定义。这样做的原因是,存在许多具有不同行为的不同转换。例如:vmap
具有广播和矢量化的参数,而 scan
具有扫描、携带和广播参数。一个变量需要为所有这些转换定义它的行为,否则一个 Module
将不兼容这些转换。或者,我们需要为如何处理转换做出默认决定。但是,这可能导致静默错误,因为给定用户意图,该行为可能实际上并不有效。
lift 包还提供了一个通用目的的 transform
,它允许任意函数转换变量集合。例如,这可用于通过转置权重来绑定捆绑自动编码器中的权重。如果在变量定义时做出提升决定,则不清楚是否可以定义类似的通用目的转换。
Linen#
提升模块不知道 Linen 的 Module
API。而是直接对 flax.core.Scope
实例进行操作。一个 Scope
实例包含 Module
的变量和 PRNG 序列。每个 Module
实例在 .scope
字段中都有一个 Scope
实例,如果它有一个父级或它是使用 init
或 apply
创建的。通常,顶级 Module
实例(您在其中调用 init
或 apply
)是唯一没有绑定到它的 Scope
的 Module
实例。
当一个 Module
被转换时,我们使用 flax.core.lift
API 来提升范围并使用 Module.clone()
来创建一个新的 Module
实例,并将提升的范围绑定到它。
flax.linen.transforms
公开对 flax.core.lift
中的转换的包装器。核心提升 API 对函数进行操作,而 Linen 包装器可以转换 Module
类或 Module
方法。
因此,提升独立于 Linen API 实现。这种关注点分离简化了实现,同时可能允许替代的 Module
抽象建立在提升和状态管理的通用核心之上。
实现#
该 pack(fn, in_vars, out_vars, rngs)
API 经历以下阶段
范围去重
此阶段仅与多个范围一起提升时相关。在这种情况下,我们必须首先找到根范围的集合。如果范围的任何祖先都不在需要提升的范围集合中,则该范围是根范围。
通过仅提升根范围,我们避免了对相同变量进行两次提升。
对于非根范围,我们存储对其祖先范围的引用以及一条路径,以便我们可以稍后重建它(阶段 4)。
筛选阶段
变量和 PRNG 序列被分成组。这样,
fn
可以分别将每一组提升到转换中。一组由指定为以下内容的过滤器定义:集合/prng 名称的列表
True
(匹配所有内容)False
(不匹配任何内容)DenyList(filter)
(匹配除指定集合之外的所有内容(例如:DenyList(['params'])
匹配除“params”集合之外的所有内容)。)。
一个集合或 PRNG 序列只能放入一组。如果一个集合匹配多个过滤器,它将被放入第一个匹配过滤器的组中。如果一个集合或 PRNG 序列不匹配任何过滤器,它将不会被提升。这意味着它不能在转换内部使用,尝试这样做会导致引发错误。例如,
in_vars = (["params"], True)
将导致“params”集合被放入第一组,所有其他集合被放入第二组。对于匹配的每个 PRNG 序列,我们通过调用
make_rng
来播种一个新的 PRNG 序列。这避免了在提升的转换完成后更新 PRNG 状态的需要。特定于转换的提升
fn
被变量和 PRNG 组调用。JAX 转换具有不同的签名和提升选项。可以说最干净的示例是vmap
。在 vmap 的情况下,函数参数、PRNG 和变量集合被传递到一个jax.vmap
包装函数中。范围重建
现在变量和 PRNG 已经在转换内部提升,我们希望重新创建提升的范围。Pack 使用一个
scope_fn
调用fn
,该函数接受提升的变量和 PRNG,并返回具有提升的变量和 rng 序列的重建范围。重新打包阶段
在使用提升的范围之后,我们必须检索更新的变量(PRNG 序列可以简单地丢弃)。pack 传递
repack_fn
来支持这一点。此阶段类似于阶段 2,只是我们只提升变量,忽略不可变变量。不可变变量不能更新。因此,它们不应从转换函数中返回。提交阶段
pack
预期fn
返回一对,其中第一项将从 pack 中返回,第二项应为重新打包的变量。更新的变量存储在原始/未提升的范围中,以便在转换完成后,在转换内部发生的变异得以保留。
使用 pack 示例#
使用 pack
来转置变量集合中的每个矩阵的最小示例
from flax.core import lift
from flax.core import Scope, init, apply, nn as core_nn
def lift_transpose(fn, target='params', variables=True, rngs=True):
# by default we transpose 'params' and simply pass through all other variables.
def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args):
# normally we would first call into a JAX transformed function here...
target, rest = variable_groups
def trans(x):
if x.ndim == 2:
return x.T
return x
target = jax.tree_util.tree_map(trans, target)
variable_groups = (target, rest)
scope = scope_fn(variable_groups, rng_groups)
y = fn(scope, *args)
out_variables = repack_fn(scope)
return y, out_variables
return lift.pack(
wrapper,
in_variable_filters=(target, variables),
out_variable_filters=(variables,),
rng_filters=(rngs,))
x = jnp.ones((3, 2))
y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)
注意,大多数用户不需要直接与 pack
交互。当您发现现有的提升转换尚不支持的用例时,请打开 GitHub 问题。
支持的转换#
Jax 转换 |
Linen 中是否支持? |
评论 |
---|---|---|
vmap |
✅ |
|
scan |
✅ |
携带变量不能在扫描主体内部初始化。 |
remat |
✅ |
|
jit |
✅ |
当前实现可能会导致不必要的重新编译。 |
jvp |
✅ |
|
vjp |
✅ |
|
custom_vjp |
✅ |
|
custom_jvp |
❌ |
|
while_loop |
✅ |
携带变量不能在 while_loop 主体内部初始化。 |
cond |
✅ |
变量初始化/变异必须在分支之间结构上匹配。 |
switch |
✅ |
变量初始化/变异必须在分支之间结构上匹配。 |
pmap |
❌ |
|
xmap |
❌ |
参考资料
Linen 示例#
回到我们最初的示例,我们现在可以使用 nn.vmap
来简化我们的实现
class LinenVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs):
VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
return VmapMLP(name='mlp')(xs)
variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
mlp: {
Dense_0: {
bias: (3, 4),
kernel: (3, 2, 4),
},
Dense_1: {
bias: (3, 1),
kernel: (3, 4, 1),
},
},
}
"""
这里我们使用 variable_axes={'params': 0}
来指示参数是矢量化的而不是共享的,而 split_rngs={'params': True}
意味着每个参数集都是独立初始化的。
我们还可以通过添加一个 BatchNorm
层来扩展示例以包含一些内部状态
class StatefulMLP(nn.Module):
@nn.compact
def __call__(self, x, *, train):
h = nn.Dense(4, name='hidden')(x)
h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train)
h = nn.relu(h)
return nn.Dense(1, name='out')(h)
class LinenStatefulVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs, *, train):
VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0)
return VmapMLP(name='mlp')(xs, train=train)
variables = LinenStatefulVmapMLP().init(random.key(0), xs)
我们只需要在 nn.vmap
中添加 'batch_stats': 0
,表示批处理统计信息是矢量化的,而不是沿着第一个轴共享的。