Flax 模块生命周期#

这份设计说明针对的是已经熟悉 Flax Linen 模块但希望进一步了解抽象背后设计原则的用户。这份说明将帮助您深入理解模块 API 建立在的假设和保证。如果您还没有实际使用过模块,请查看快速入门指南

Flax Linen 模块在 Flax 核心之上提供了一个 Pythonic 抽象。模块 抽象允许您创建在 JAX 之上具有状态、参数和随机性的类。这是一份有关Module 类的设计和行为的实用指南。最后,您应该可以放心地走出常规,以新方式使用模块。

概述#

定义#

让我们从模块生命周期的概述开始。首先,定义一个简单的模块

class MLP(nn.Module):
  # 1. Attribute annotations
  hidden_size: int
  out_size: int

  # 2. The ``setup`` method
  def setup(self):
    self.hidden = nn.Dense(self.hidden_size)
    self.out = nn.Dense(self.out_size)

  # 3. User methods
  def __call__(self, x):
    a = self.hidden(x)
    h = nn.relu(a)
    return self.out(h)

这个模块由以下部分组成:

  1. **属性注释**,定义为dataclass 字段。这些注释会自动定义构造函数。

  2. **``setup`` 方法**,它创建子模块并将它们分配给属性。

  3. **用户方法**。按照惯例,大多数模块只有一个__call__ 方法,但您可以定义多个方法或使用不同的方法名称。

构造/初始化#

现在,我们要构造和使用MLP 模块

mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)

首先,我们构造一个MLP 实例,并传入构造属性。请注意,这里的构造与您在不熟悉函数式编程模式的情况下可能期望的不同。MLP 构造函数实际上并没有创建变量或任何内部状态。最好将其视为包含功能但没有数据的模块的规范或模板。

让我们仔细看看初始化。令人惊讶的是,Flax 中实际上没有单独的初始化路径。调用init 只是apply 的一个特例,您也可以将其写成

# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)

因此,init 只是apply 的一个包装器,其中

  1. 我们调用没有任何初始变量(空字典)的模块。

  2. 始终会传入名为"params" 的 PRNG 生成器,用于使用参数初始化函数随机初始化参数。

  3. 所有变量集合都设置为可变的 (mutable=True)。当集合可变时,可以更新现有变量并创建新变量。因此,在init 内部,可以在任何变量集合中初始化变量,并且它们都将被添加到返回的变量字典中。

生命周期#

既然您已经了解了initapply 的一个特例,那么让我们更详细地了解.apply(...)。实际上,模块的大部分复杂性都位于apply 方法中。“模块生命周期”包括构造和apply 模块。我们可以将模块生命周期总结如下:

  1. 我们构造mlp = MLP(hidden_size=5, out_size=3),因此mlp.hidden_size=5mlp.out_size=3

  2. 然后,调用mlp.apply,它会

    1. 克隆mlp,我们将其称为mlp_copy

    2. 调用mlp_copy.setup()

    3. 返回mlp_copy.__call__() 的输出,并可选地返回使用关键字参数mutable= 指定为可变的变量集合。

请注意,生命周期包括克隆模块实例。这样做是为了确保apply 可以被视为一个纯函数(即,如果传入相同的参数,它将返回相同的输出)。您将在后面的顶层模块 部分中详细了解这一点。

变量#

术语“变量”在编程和数学中很常见。但是,重要的是要了解 JAX 和 Flax 中变量的含义。在 Flax 模块内部,变量 的行为与您对 Python 的期望一致。它们被初始化一次,被读取,并且可能还会定期更新。但是,JAX 没有变量的概念。相反,值存储在类似于 NumPy 数组的数组中 - 它们有一个重要的区别:它们是不可变的。

initapply 方法将变量作为嵌套字典返回,字典的键是字符串,叶节点是 JAX 数组。在顶层,每个键对应一个变量集合。在每个集合内部,嵌套字典结构对应于Module 层次结构。变量字典是不可变的,因此实际上只是变量所处的状态的快照。当再次调用apply 时,变量字典将作为参数传递。这样,变量就会处于与上一次init / apply 调用结束时相同的状态。

注意

模块字段使用 field_name: TypeHint 语法声明(与 dataclasses 相同)。如果没有类型提示,属性将被视为类的静态属性。如果您无法指定类型,可以使用typing.Any 作为通配符类型。

紧凑模块#

Linen 提供了一个替代的 API,可以更紧凑地定义模块。这对于模块仅由一个使用参数和/或子模块的方法组成的常见情况特别有用。使用紧凑 API,可以将 MLP 重写如下:

class CompactMLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x):
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)

紧凑的Module 在精神上类似于函数。它提供了一种简洁的表示法,并限制了对函数的输入和返回值的外部交互。在这种情况下,简洁的表示法可能使其他人更容易理解模块的作用。无需在setup__call__ 方法之间来回跳转以了解子模块的作用。相反,只需从上到下阅读一次__call__ 方法,就可以获得一个简洁的概述。如果您正在实现具有许多超参数的复杂模块,这将产生显著的影响。请参阅setup 或 compact,以了解如何选择 setup 和 compact 的实用指南。

在行内定义子模块和/或变量的另一个好处是,您可以在构造变量时向方法添加参数。最常见的例子是使用形状信息来确定参数的形状,如下所示:

class CompactScaledMLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x):
    scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:])
    x *= scale[None]
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)

许多标准 Linen 模块(如nn.Dense)已经使用形状推断,以避免需要指定输入形状(如 Dense 层的输入特征数量)。

紧凑控制流#

定义子模块的顺序决定了子模块的名称(如果未显式提供,则使用传递给模块构造函数的name= 关键字参数)。由于name 决定了参数如何映射到子模块,因此在将控制流与自动生成的名称混合使用时,您必须小心。使用控制流可以更改顺序或完全删除某些子模块。这在子模块应该仅根据某些构造参数存在时很有用。但是,当控制流依赖于模块的输入参数时,您应该小心。例如,以下模块将无法正常工作

class WrongModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    if mode == "encode":
      return nn.Dense(features=8)(x)
    elif mode == "decode":
      return nn.Dense(features=4)(x)

上面的模块将被破坏,因为编码器或解码器路径将构建一个名为“Dense_0”的模块。这意味着这两个模块将共享参数,而这并非我们想要的结果。实际上,这两个模块不能共享参数,因为它们各自具有不同的特征数量。

这个问题可以通过多种方式解决
  • 提供显式名称

  • setup中创建模块

  • 或将构造函数移出控制流。

后者是按如下方式完成的

class CorrectModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    encoder = nn.Dense(8)
    decoder = nn.Dense(4)
    if mode == "encode":
      return encoder(x)
    elif mode == "decode":
      return decoder(x)

在上面的示例中,构造顺序是固定的。构造完成后,子模块可以以任意顺序使用。

注意

紧凑模块与React hooks非常相似。

顶级模块#

当在“顶级”创建模块实例时,它将处于“未绑定”状态 - 也就是说,它没有附加任何变量。“顶级”意味着它不是在另一个模块类内部作为子模块构建的。除了调用initapply之外,你无法对未绑定模块做太多事情。还要注意,未绑定模块上不会调用setup,因此你只能访问构造参数。请参考未来工作部分了解这在将来如何变化。

为什么顶级模块始终未绑定?#

当我们调用apply时,将创建顶级模块的副本,该副本实际上将保存变量和PRNG序列。这种有状态的“绑定”克隆仅在我们执行apply方法时存在。这样做的原因是,如果你创建了一个有状态的对象并在apply函数返回之前销毁它,那么apply函数本身就像一个纯函数。纯函数有两个约束

  1. 如果你输入相同的参数,它将返回相同的输出

  2. 它不会改变函数外部的任何内容。这意味着你不能操作纯函数外部可访问的有状态对象。

纯函数有很多优点,但在使用JAX时,它们通常是必不可少的。例如,大多数代码需要使用jax.jit进行编译才能快速执行,并且一旦你创建了模块,你可能希望使用jax.grad优化其参数。但是,这些API需要纯函数,并且不能直接在有状态的绑定Module实例上运行。此外,纯函数允许灵活地与其他库互操作。例如,我们推荐使用Optax来优化参数。Optax中的优化器需要并返回JAX数组的PyTree来进行优化,就像Linen模块的apply函数一样。

克隆#

为了使这种方法可靠地工作,我们需要定义明确的克隆行为。Flax不依赖于像Python的deepcopy那样复杂的嵌套克隆过程,而是强制要求Module完全由其构造参数定义。因此,克隆模块简化为使用其原始构造参数调用构造函数。由于Module充当不可变数据类,因此构造参数将直接映射到实例属性。在setup__post_init__中计算的非构造属性也应仅依赖于构造参数,以确保定义明确的克隆。

绑定#

有时,在不将代码包装在函数中的情况下,拥有一个绑定的顶级模块很有用。例如:在Jupyter笔记本中与模块交互。该bind方法返回一个绑定的克隆,其生命周期不受限制。这样做的缺点是你不能将其与JAX转换组合使用,也不能将其集成到期望无状态代码的普通JAX代码库中。例如,Optax可以优化参数的Pytree,但它不能直接优化使用.bind创建的绑定Module实例(因为这不是Pytree)。因此,你不能将bindAPI与像Optax这样的函数优化器API组合使用。

设置#

setup方法通常用作普通Python类中的构造函数钩子(__init__)。但是,对于更高级的用例,最好认识到它与构造函数并不完全相同。

setup仅在模块绑定后才会被调用。通常情况下,这不是问题,因为大多数模块(几乎)立即绑定(作为initapply的一部分)。在setup内部,当子模块被分配给属性时,它们将被绑定。在nn.compact装饰的方法内部,子模块在构造时立即被绑定。如上一节所述,顶级模块永远不会绑定,因此在构造时不会调用setup。这意味着你无法从未绑定的顶级模块访问在setup中分配的属性。

class TopLevelAccess(nn.Module):

  def setup(self):
    self.foo = nn.Dense(2)

mdl = TopLevelAccess()
assert not hasattr(mdl, "foo")  # foo is not defined because setup is not called

setup方法不是在Module绑定后立即被调用,而是在你与Module实例交互时才被调用(例如:调用方法或访问属性)。这不会影响Module的行为,但延迟执行有时会影响调试期间的日志语句和堆栈跟踪。关于函数化的部分将解释为什么首先需要setup是延迟的。

函数化#

到目前为止,我们有一个纯apply函数,该函数通常使用一些JAX转换进行转换,并且在apply内部,我们有一个有状态的模块实例可以使用。换句话说:在模块外部,我们处于函数世界中,在那里我们拥有JAX的函数转换功能,而在模块内部,我们拥有Flax的有状态变量和PRNG序列的功能,并且apply方法是我们在这两个世界之间的桥梁。

但是,如果我们想在模块内部使用JAX转换怎么办?答案是函数化。

这个过程本身很繁琐且容易出错,但由Flax在内部处理。从高层次上讲,我们可以将其总结如下。对于在模块中定义的方法fn

  1. 收集应该在JAX转换内部可用的模块的狀態(变量和PRNG序列),并对其进行快照。

  2. 使用原始参数和收集的狀態调用JAX转换。然后在转换内部

    1. 解压缩状态并重新创建模块

    2. 调用用户代码fn

    3. 收集更新的变量和rng,并将其与fn的原始返回值一起返回

  3. 使用从转换返回的更新状态更新原始状态。

可以在Lifted Transformation设计说明中找到有关函数化和提升的更深入的解释。

实际后果#

在大多数情况下,函数化是自动为你处理的。但是,你必须考虑一些约束条件。最重要的是,Flax只处理有状态的基元(Linen变量和RNG),而不是任意的有状态Python代码。最重要的是:你不能关闭有状态对象和Module对象,因为它们对Flax的内部机制(以及JAX本身)是不可见的。

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda x: dense(x) + 1
    # simply calling inner works fine
    # return self.inner(x, fn)
    # but applying a transformation doesn't:
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, fn)

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

这里inner接受一个函数,该函数关闭了一个模块实例。在这个例子中,这很好,因为我们没有使用提升的转换来转换内部方法。大多数方法都没有被转换,但了解如何使模块方法可转换是很有用的。

可转换性的主要障碍是JAX无法识别的类型。JAX只理解Pytree参数;即(Jax)numpy ndarray和Python数字/布尔值的任意嵌套Python容器(字典、列表、元组)。Flax允许使用flax.struct API定义与Pytree兼容的数据类。

函数闭包是意外地从转换中隐藏JAX数组或Linen模块的最常见方式。但是,如果你想传递与JAX和Linen转换兼容的闭包,则有一个简单的解决方法

class Partial(flax.struct.PyTreeNode):
  fn: Callable = flax.struct.field(pytree_node=False)
  args: Iterable[Any]

  def __call__(self, *args, **kwargs):
    return self.fn(*(tuple(self.args) + args), **kwargs)

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda mdl, x: mdl(x) + 1
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, Partial(fn, [dense]))

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

这里,闭包使用Flax数据类实现。该函数本身用flax.struct.field(pytree_node=False)进行注释,以指示它不包含JAX数组或Linen模块。另一方面,部分应用的args被视为pytree容器。我们将闭包重写为使用Partial。现在,内部方法可以使用提升的转换进行转换。

未来工作#

未绑定模块的设置#

目前的模块抽象在构造后初始化字段方面特别严格。在当前的模块 API 中,setup 方法是初始化模块实例字段的地方。因为 setup 仅在绑定模块上调用,所以完整的模块 API 在 setup 内部可用,包括变量声明。然而,很多时候我们实际上并不需要任何有状态的 API 来初始化字段。事实上,最常见的情况是,我们只是想声明一个子模块。更重要的是,检查子模块以进行调试或部分运行模型通常很有用。例如

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder(...)
    self.decoder = Decoder(...)

假设我们想要使用 auto_encoder.decoder.apply(decoder_variables, x) 调用解码器。使用当前的 setup API,这行不通,因为我们必须先绑定变量才能调用 setup 并定义解码器属性。当然,我们可以使用与 setup 中相同的属性手动构造解码器模块,但这在许多情况下并不理想。

有两种可能的解决方案可以使这种用例更符合人体工程学。首先,可以使 setup 在构造后立即运行,然后再绑定。这意味着你仍然可以创建子模块,但不能再定义或操作变量。因此,这将是一个重大更改,它需要一个新的 API 来延迟定义变量。

或者,可以引入一个额外的特殊方法,该方法在模块构造后立即运行,并在绑定之前运行。在这种情况下,setup 方法将保留其原始语义。