Flax 模块生命周期#
这份设计说明针对的是已经熟悉 Flax Linen 模块但希望进一步了解抽象背后设计原则的用户。这份说明将帮助您深入理解模块 API 建立在的假设和保证。如果您还没有实际使用过模块,请查看快速入门指南。
Flax Linen 模块在 Flax 核心之上提供了一个 Pythonic 抽象。模块 抽象允许您创建在 JAX 之上具有状态、参数和随机性的类。这是一份有关Module
类的设计和行为的实用指南。最后,您应该可以放心地走出常规,以新方式使用模块。
概述#
定义#
让我们从模块生命周期的概述开始。首先,定义一个简单的模块
class MLP(nn.Module):
# 1. Attribute annotations
hidden_size: int
out_size: int
# 2. The ``setup`` method
def setup(self):
self.hidden = nn.Dense(self.hidden_size)
self.out = nn.Dense(self.out_size)
# 3. User methods
def __call__(self, x):
a = self.hidden(x)
h = nn.relu(a)
return self.out(h)
这个模块由以下部分组成:
**属性注释**,定义为dataclass 字段。这些注释会自动定义构造函数。
**``setup`` 方法**,它创建子模块并将它们分配给属性。
**用户方法**。按照惯例,大多数模块只有一个
__call__
方法,但您可以定义多个方法或使用不同的方法名称。
构造/初始化#
现在,我们要构造和使用MLP
模块
mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)
首先,我们构造一个MLP
实例,并传入构造属性。请注意,这里的构造与您在不熟悉函数式编程模式的情况下可能期望的不同。MLP
构造函数实际上并没有创建变量或任何内部状态。最好将其视为包含功能但没有数据的模块的规范或模板。
让我们仔细看看初始化。令人惊讶的是,Flax 中实际上没有单独的初始化路径。调用init
只是apply
的一个特例,您也可以将其写成
# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)
因此,init
只是apply
的一个包装器,其中
我们调用没有任何初始变量(空字典)的模块。
始终会传入名为
"params"
的 PRNG 生成器,用于使用参数初始化函数随机初始化参数。所有变量集合都设置为可变的 (
mutable=True
)。当集合可变时,可以更新现有变量并创建新变量。因此,在init
内部,可以在任何变量集合中初始化变量,并且它们都将被添加到返回的变量字典中。
生命周期#
既然您已经了解了init
是apply
的一个特例,那么让我们更详细地了解.apply(...)
。实际上,模块的大部分复杂性都位于apply
方法中。“模块生命周期”包括构造和apply
模块。我们可以将模块生命周期总结如下:
我们构造
mlp = MLP(hidden_size=5, out_size=3)
,因此mlp.hidden_size=5
且mlp.out_size=3
。然后,调用
mlp.apply
,它会克隆
mlp
,我们将其称为mlp_copy
。调用
mlp_copy.setup()
。返回
mlp_copy.__call__()
的输出,并可选地返回使用关键字参数mutable=
指定为可变的变量集合。
请注意,生命周期包括克隆模块实例。这样做是为了确保apply
可以被视为一个纯函数(即,如果传入相同的参数,它将返回相同的输出)。您将在后面的顶层模块 部分中详细了解这一点。
变量#
术语“变量”在编程和数学中很常见。但是,重要的是要了解 JAX 和 Flax 中变量的含义。在 Flax 模块内部,变量 的行为与您对 Python 的期望一致。它们被初始化一次,被读取,并且可能还会定期更新。但是,JAX 没有变量的概念。相反,值存储在类似于 NumPy 数组的数组中 - 它们有一个重要的区别:它们是不可变的。
init
和apply
方法将变量作为嵌套字典返回,字典的键是字符串,叶节点是 JAX 数组。在顶层,每个键对应一个变量集合。在每个集合内部,嵌套字典结构对应于Module
层次结构。变量字典是不可变的,因此实际上只是变量所处的状态的快照。当再次调用apply
时,变量字典将作为参数传递。这样,变量就会处于与上一次init
/ apply
调用结束时相同的状态。
注意
模块字段使用 field_name: TypeHint 语法声明(与 dataclasses 相同)。如果没有类型提示,属性将被视为类的静态属性。如果您无法指定类型,可以使用typing.Any
作为通配符类型。
紧凑模块#
Linen 提供了一个替代的 API,可以更紧凑地定义模块。这对于模块仅由一个使用参数和/或子模块的方法组成的常见情况特别有用。使用紧凑 API,可以将 MLP 重写如下:
class CompactMLP(nn.Module):
hidden_size: int
out_size: int
@nn.compact
def __call__(self, x):
a = nn.Dense(self.hidden_size)(x)
h = nn.relu(a)
return nn.Dense(self.out_size)(h)
紧凑的Module
在精神上类似于函数。它提供了一种简洁的表示法,并限制了对函数的输入和返回值的外部交互。在这种情况下,简洁的表示法可能使其他人更容易理解模块的作用。无需在setup
和__call__
方法之间来回跳转以了解子模块的作用。相反,只需从上到下阅读一次__call__
方法,就可以获得一个简洁的概述。如果您正在实现具有许多超参数的复杂模块,这将产生显著的影响。请参阅setup 或 compact,以了解如何选择 setup 和 compact 的实用指南。
在行内定义子模块和/或变量的另一个好处是,您可以在构造变量时向方法添加参数。最常见的例子是使用形状信息来确定参数的形状,如下所示:
class CompactScaledMLP(nn.Module):
hidden_size: int
out_size: int
@nn.compact
def __call__(self, x):
scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:])
x *= scale[None]
a = nn.Dense(self.hidden_size)(x)
h = nn.relu(a)
return nn.Dense(self.out_size)(h)
许多标准 Linen 模块(如nn.Dense
)已经使用形状推断,以避免需要指定输入形状(如 Dense 层的输入特征数量)。
紧凑控制流#
定义子模块的顺序决定了子模块的名称(如果未显式提供,则使用传递给模块构造函数的name=
关键字参数)。由于name
决定了参数如何映射到子模块,因此在将控制流与自动生成的名称混合使用时,您必须小心。使用控制流可以更改顺序或完全删除某些子模块。这在子模块应该仅根据某些构造参数存在时很有用。但是,当控制流依赖于模块的输入参数时,您应该小心。例如,以下模块将无法正常工作
class WrongModule(nn.Module):
@nn.compact
def __call__(self, x, mode):
if mode == "encode":
return nn.Dense(features=8)(x)
elif mode == "decode":
return nn.Dense(features=4)(x)
上面的模块将被破坏,因为编码器或解码器路径将构建一个名为“Dense_0”的模块。这意味着这两个模块将共享参数,而这并非我们想要的结果。实际上,这两个模块不能共享参数,因为它们各自具有不同的特征数量。
- 这个问题可以通过多种方式解决
提供显式名称
在
setup
中创建模块或将构造函数移出控制流。
后者是按如下方式完成的
class CorrectModule(nn.Module):
@nn.compact
def __call__(self, x, mode):
encoder = nn.Dense(8)
decoder = nn.Dense(4)
if mode == "encode":
return encoder(x)
elif mode == "decode":
return decoder(x)
在上面的示例中,构造顺序是固定的。构造完成后,子模块可以以任意顺序使用。
注意
紧凑模块与React hooks非常相似。
顶级模块#
当在“顶级”创建模块实例时,它将处于“未绑定”状态 - 也就是说,它没有附加任何变量。“顶级”意味着它不是在另一个模块类内部作为子模块构建的。除了调用init
和apply
之外,你无法对未绑定模块做太多事情。还要注意,未绑定模块上不会调用setup
,因此你只能访问构造参数。请参考未来工作部分了解这在将来如何变化。
为什么顶级模块始终未绑定?#
当我们调用apply
时,将创建顶级模块的副本,该副本实际上将保存变量和PRNG序列。这种有状态的“绑定”克隆仅在我们执行apply方法时存在。这样做的原因是,如果你创建了一个有状态的对象并在apply函数返回之前销毁它,那么apply
函数本身就像一个纯函数。纯函数有两个约束
如果你输入相同的参数,它将返回相同的输出
它不会改变函数外部的任何内容。这意味着你不能操作纯函数外部可访问的有状态对象。
纯函数有很多优点,但在使用JAX时,它们通常是必不可少的。例如,大多数代码需要使用jax.jit
进行编译才能快速执行,并且一旦你创建了模块,你可能希望使用jax.grad
优化其参数。但是,这些API需要纯函数,并且不能直接在有状态的绑定Module
实例上运行。此外,纯函数允许灵活地与其他库互操作。例如,我们推荐使用Optax来优化参数。Optax中的优化器需要并返回JAX数组的PyTree来进行优化,就像Linen模块的apply
函数一样。
克隆#
为了使这种方法可靠地工作,我们需要定义明确的克隆行为。Flax不依赖于像Python的deepcopy
那样复杂的嵌套克隆过程,而是强制要求Module
完全由其构造参数定义。因此,克隆模块简化为使用其原始构造参数调用构造函数。由于Module
充当不可变数据类,因此构造参数将直接映射到实例属性。在setup
或__post_init__
中计算的非构造属性也应仅依赖于构造参数,以确保定义明确的克隆。
设置#
该setup
方法通常用作普通Python类中的构造函数钩子(__init__
)。但是,对于更高级的用例,最好认识到它与构造函数并不完全相同。
setup
仅在模块绑定后才会被调用。通常情况下,这不是问题,因为大多数模块(几乎)立即绑定(作为init
和apply
的一部分)。在setup
内部,当子模块被分配给属性时,它们将被绑定。在nn.compact
装饰的方法内部,子模块在构造时立即被绑定。如上一节所述,顶级模块永远不会绑定,因此在构造时不会调用setup。这意味着你无法从未绑定的顶级模块访问在setup中分配的属性。
class TopLevelAccess(nn.Module):
def setup(self):
self.foo = nn.Dense(2)
mdl = TopLevelAccess()
assert not hasattr(mdl, "foo") # foo is not defined because setup is not called
该setup
方法不是在Module
绑定后立即被调用,而是在你与Module
实例交互时才被调用(例如:调用方法或访问属性)。这不会影响Module
的行为,但延迟执行有时会影响调试期间的日志语句和堆栈跟踪。关于函数化的部分将解释为什么首先需要setup
是延迟的。
函数化#
到目前为止,我们有一个纯apply
函数,该函数通常使用一些JAX转换进行转换,并且在apply
内部,我们有一个有状态的模块实例可以使用。换句话说:在模块外部,我们处于函数世界中,在那里我们拥有JAX的函数转换功能,而在模块内部,我们拥有Flax的有状态变量和PRNG序列的功能,并且apply
方法是我们在这两个世界之间的桥梁。
但是,如果我们想在模块内部使用JAX转换怎么办?答案是函数化。
这个过程本身很繁琐且容易出错,但由Flax在内部处理。从高层次上讲,我们可以将其总结如下。对于在模块中定义的方法fn
收集应该在JAX转换内部可用的模块的狀態(变量和PRNG序列),并对其进行快照。
使用原始参数和收集的狀態调用JAX转换。然后在转换内部
解压缩状态并重新创建模块
调用用户代码
fn
收集更新的变量和rng,并将其与
fn
的原始返回值一起返回
使用从转换返回的更新状态更新原始状态。
可以在Lifted Transformation设计说明中找到有关函数化和提升的更深入的解释。
实际后果#
在大多数情况下,函数化是自动为你处理的。但是,你必须考虑一些约束条件。最重要的是,Flax只处理有状态的基元(Linen变量和RNG),而不是任意的有状态Python代码。最重要的是:你不能关闭有状态对象和Module
对象,因为它们对Flax的内部机制(以及JAX本身)是不可见的。
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
dense = nn.Dense(x.shape[-1])
fn = lambda x: dense(x) + 1
# simply calling inner works fine
# return self.inner(x, fn)
# but applying a transformation doesn't:
vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
return vmap_inner(self, x, fn)
def inner(self, x, fn):
for i in range(3):
x = fn(x)
return x
这里inner
接受一个函数,该函数关闭了一个模块实例。在这个例子中,这很好,因为我们没有使用提升的转换来转换内部方法。大多数方法都没有被转换,但了解如何使模块方法可转换是很有用的。
可转换性的主要障碍是JAX无法识别的类型。JAX只理解Pytree参数;即(Jax)numpy ndarray和Python数字/布尔值的任意嵌套Python容器(字典、列表、元组)。Flax允许使用flax.struct API定义与Pytree兼容的数据类。
函数闭包是意外地从转换中隐藏JAX数组或Linen模块的最常见方式。但是,如果你想传递与JAX和Linen转换兼容的闭包,则有一个简单的解决方法
class Partial(flax.struct.PyTreeNode):
fn: Callable = flax.struct.field(pytree_node=False)
args: Iterable[Any]
def __call__(self, *args, **kwargs):
return self.fn(*(tuple(self.args) + args), **kwargs)
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
dense = nn.Dense(x.shape[-1])
fn = lambda mdl, x: mdl(x) + 1
vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
return vmap_inner(self, x, Partial(fn, [dense]))
def inner(self, x, fn):
for i in range(3):
x = fn(x)
return x
这里,闭包使用Flax数据类实现。该函数本身用flax.struct.field(pytree_node=False)
进行注释,以指示它不包含JAX数组或Linen模块。另一方面,部分应用的args
被视为pytree容器。我们将闭包重写为使用Partial。现在,内部方法可以使用提升的转换进行转换。
未来工作#
未绑定模块的设置#
目前的模块抽象在构造后初始化字段方面特别严格。在当前的模块 API 中,setup
方法是初始化模块实例字段的地方。因为 setup
仅在绑定模块上调用,所以完整的模块 API 在 setup
内部可用,包括变量声明。然而,很多时候我们实际上并不需要任何有状态的 API 来初始化字段。事实上,最常见的情况是,我们只是想声明一个子模块。更重要的是,检查子模块以进行调试或部分运行模型通常很有用。例如
class AutoEncoder(nn.Module):
def setup(self):
self.encoder = Encoder(...)
self.decoder = Decoder(...)
假设我们想要使用 auto_encoder.decoder.apply(decoder_variables, x) 调用解码器。使用当前的 setup API,这行不通,因为我们必须先绑定变量才能调用 setup 并定义解码器属性。当然,我们可以使用与 setup 中相同的属性手动构造解码器模块,但这在许多情况下并不理想。
有两种可能的解决方案可以使这种用例更符合人体工程学。首先,可以使 setup 在构造后立即运行,然后再绑定。这意味着你仍然可以创建子模块,但不能再定义或操作变量。因此,这将是一个重大更改,它需要一个新的 API 来延迟定义变量。
或者,可以引入一个额外的特殊方法,该方法在模块构造后立即运行,并在绑定之前运行。在这种情况下,setup
方法将保留其原始语义。