JAX风格NNX转换#

  • 作者:Cristian Garcia,Anselm Levskaya

  • 日期:2024年6月

  • FLIP PR:#4107

  • 状态:实施中

动机#

NNX允许用户在顶级使用模块,因为它们是急切初始化并包含自身状态。这自然会导致用户希望将它们与转换一起使用,并很快开始使用NNX转换。由于NNX模块类似于PyTree,因为它们包含数组,因此新用户通常会尝试应用JAX约定,例如

@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
  ...

但是,这可能会产生误导。目前,NNX转换遵循Linen的约定,将输入模块视为一个单元(所有模块一起拆分以保留共享引用),并提供用于单独转换该状态的API。前面的示例实际上转换为

# this is what is really happening
@nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0})
def f(m1: Module, m2: Module):
  ...

请注意,IGNORE不是真正的符号,而是表示此处放置的任何值都不会影响结果,因为模块将被空的PyTree占位符替换(类似于None)。state_axes参数控制状态如何通过高级Filter到其所需轴的映射进行向量化。在此示例中,...(省略号)是接受所有内容的过滤器,因此默认情况下所有状态都在第0轴上向量化。

为了表达他们最初的意图,用户必须求助于更复杂的自定义过滤器,这些过滤器猜测单体中每个模块的索引。虽然在简单的情况下这很简单,但用户通常需要计算索引(模块按jax.tree.leavesargs上的顺序出现)

select_m1 = lambda path, value: path[0] == 0
select_m2 = lambda path, value: path[0] == 1

# To select modules individually, you must create a filter (which can be tricky)
@nnx.vmap(state_axes={select_m1: 1, select_m2: 0})
def f(m1: Module, m2: Module):
  ...

如果JAX约定“Just Worked™”会怎样?#

此提案旨在使NNX转换与用户基于其JAX经验的期望保持一致,使语法尽可能直观地工作。原始示例将如同m1m2分别在轴10上向量化的PyTree一样工作。

@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
  ...

这种方法的主要优势在于,对于vmapscan,我们可以消除state_axessplit_rngs参数,仅依靠in_axes API。仅此语法可能足以满足80-90%的用例,因为用户倾向于以可预测的方式管理状态。

Lift符号#

为了能够在每个模块内进行更细粒度的状态控制,我们引入了Lift API。通过在树前缀的位置使用包含状态过滤器的特殊类型,现在可以结构化地进行状态提升。这允许将不同的过滤器应用于参数中的不同模块,而无需使用复杂的基于路径的过滤器。理想情况下,每个转换都将支持自己的Lift类型,通过现有的JAX API添加所需的行为。

例如,在vmap中,我们可以允许StateAxes实例(vmap的Lift类型)被in/out_axes接受以控制子状态如何通过将状态Filter映射到轴说明符来处理。

state_axes = StateAxes({Param: 1, BatchStat: None})

@nnx.vmap(in_axes=(state_axes, 0))
def f(m1: Module, m2: Module):
  ...

在这种情况下,m1Param在轴1上向量化,而其BatchStat被广播,并且m2的整个状态在轴0上向量化。

对于nnx.grad,我们可以允许DiffStateargnums参数中使用,以指定要微分参数的位置和指定模块的可微分状态的过滤器。

grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y)

Rng处理#

为了简化RNG状态处理,我们建议删除vmapscan中单独的split_rngs参数。相反,我们建议引入一个新的nnx.split_rngs API,它将在转换之前和之后管理RNG处理。这种方法为用户提供了更明确的控制,并且与JAX转换行为更好地保持一致。

一致的别名#

为了确保遵循引用语义的对象的转换的正确性,我们必须对所有引用的别名强制执行一致的提升/降低规范。转换必须遵守两个规则

  1. 引用的所有别名都必须接收**完全相同**的提升/降低规范。

  2. 在转换函数的输出上不允许捕获引用。

例如

@nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes)
def f(m1, m2, m1_alias):
  return m2

m2 = f(m1, m2, m1)

这里,m1有两个输入别名,因为它作为第一个和第三个输入传递给f,但这是可以接受的,因为m1_axesin_axes中都分配给了它。m2作为第二个输入传递并具有输出别名,这也是可以接受的,因为m2_axesin_axesout_axes中都分配了。

让我们检查一些根据这些标准应被拒绝的程序示例

不一致的输入别名#

考虑一个具有两个参数m1m2的函数,分别在轴01上向量化。将同一个模块作为这两个参数传递将是不一致的。

@nnx.vmap(in_axes=(0, 1))
def f(m1: Module, m2: Module):
  ...

f(m, m)  # This should be rejected

不一致的输入/输出别名#

现在考虑在vmap下具有in_axes=0out_axes=1的恒等函数g。在JAX中,这将导致转置输入中的数组。

@nnx.vmap(in_axes=0, out_axes=1)
def g(m: Module):
  return m

虽然这看起来是正确的,但在NNX中这种行为没有明确定义,因为共享的可变引用充当辅助输出。在后台,g被转换为一个函数,该函数将输入作为额外的第一个输出,并且out_axes为该输出设置为与in_axes相同的值。

@nnx.vmap(in_axes=0, out_axes=(0, 1))
def g_real(m: Module):
  return m, m

此返回结构揭示了一个不一致之处:我们试图使用out_axes=0out_axes=1都降低m

嵌套结构中不一致的别名#

类似的问题可能出现在不太明显的情况下,例如当m包含在另一个结构中时。

@nnx.vmap(in_axes=0, out_axes=1)
def f(m: Module):
  return SomeModule(m)

这意味着我们必须遍历输入和输出的整个图形以检查一致的赋值。当传递具有不同规范的共享引用输入/输出时,也会发生同样的问题。

shared = Shared()
m1, m2 = Foo(shared), Foo(shared)

@nnx.vmap(in_axes=(0, 1))
def f(m1, m2):  # shared is passed through both
  ...

捕获的模块不能作为输出#

最后,让我们考虑第二个一致别名规则,该规则规定捕获的模块不能作为输出。这里的主要问题是NNX需要将所有输入引用一起拆分以跟踪更改,但捕获的模块绕过了此过程。将它们视为新引用会导致**隐式克隆**。

m = SomeModule()

@nnx.vmap(out_axes=0, axis_size=5)
def f():
  return m

assert m is not f()  # implicit cloning

为了保留引用标识,我们必须不允许捕获的模块作为输出。在实践中,我们可以使用用于限制来自不同级别的模块上的状态更新的跟踪级别上下文机制来检测捕获的模块。

总结#

在本文件中,我们

  • 讨论了当前实现中的一些问题,这些问题使得 JAX 用户难以理解。

  • 提议重构 NNX 变换,以允许用户在与对象交互时使用常规 JAX 语义,从而删除由 NNX 变换引入的额外参数。

  • 引入了在 JAX API 中使用 Lift 类型来弥补 NNX 对象中缺乏“前缀”概念的不足,从而实现模块子状态的独立提升。

  • 提议了一个新的nnx.split_rngs API 来替换vmapscan中的split_rngs参数,使 RNG 处理成为显式操作,并为用户提供更多控制权。

  • 分析了由共享可变引用别名导致的边缘情况,并提议对所有具有输入语义的变换实施**一致别名**。