JAX风格NNX转换#
作者:Cristian Garcia,Anselm Levskaya
日期:2024年6月
FLIP PR:#4107
状态:实施中
动机#
NNX允许用户在顶级使用模块,因为它们是急切初始化并包含自身状态。这自然会导致用户希望将它们与转换一起使用,并很快开始使用NNX转换。由于NNX模块类似于PyTree,因为它们包含数组,因此新用户通常会尝试应用JAX约定,例如
@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
...
但是,这可能会产生误导。目前,NNX转换遵循Linen的约定,将输入模块视为一个单元(所有模块一起拆分以保留共享引用),并提供用于单独转换该状态的API。前面的示例实际上转换为
# this is what is really happening
@nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0})
def f(m1: Module, m2: Module):
...
请注意,IGNORE
不是真正的符号,而是表示此处放置的任何值都不会影响结果,因为模块将被空的PyTree占位符替换(类似于None
)。state_axes
参数控制状态如何通过高级Filter
到其所需轴的映射进行向量化。在此示例中,...
(省略号)是接受所有内容的过滤器,因此默认情况下所有状态都在第0轴上向量化。
为了表达他们最初的意图,用户必须求助于更复杂的自定义过滤器,这些过滤器猜测单体中每个模块的索引。虽然在简单的情况下这很简单,但用户通常需要计算索引(模块按jax.tree.leaves
在args
上的顺序出现)
select_m1 = lambda path, value: path[0] == 0
select_m2 = lambda path, value: path[0] == 1
# To select modules individually, you must create a filter (which can be tricky)
@nnx.vmap(state_axes={select_m1: 1, select_m2: 0})
def f(m1: Module, m2: Module):
...
如果JAX约定“Just Worked™”会怎样?#
此提案旨在使NNX转换与用户基于其JAX经验的期望保持一致,使语法尽可能直观地工作。原始示例将如同m1
和m2
分别在轴1
和0
上向量化的PyTree一样工作。
@nnx.vmap(in_axes=(1, 0))
def f(m1: Module, m2: Module):
...
这种方法的主要优势在于,对于vmap
和scan
,我们可以消除state_axes
和split_rngs
参数,仅依靠in_axes
API。仅此语法可能足以满足80-90%的用例,因为用户倾向于以可预测的方式管理状态。
Lift符号#
为了能够在每个模块内进行更细粒度的状态控制,我们引入了Lift
API。通过在树前缀的位置使用包含状态过滤器的特殊类型,现在可以结构化地进行状态提升。这允许将不同的过滤器应用于参数中的不同模块,而无需使用复杂的基于路径的过滤器。理想情况下,每个转换都将支持自己的Lift类型,通过现有的JAX API添加所需的行为。
例如,在vmap
中,我们可以允许StateAxes
实例(vmap的Lift类型)被in/out_axes
接受以控制子状态如何通过将状态Filter
映射到轴说明符来处理。
state_axes = StateAxes({Param: 1, BatchStat: None})
@nnx.vmap(in_axes=(state_axes, 0))
def f(m1: Module, m2: Module):
...
在这种情况下,m1
的Param
在轴1
上向量化,而其BatchStat
被广播,并且m2
的整个状态在轴0
上向量化。
对于nnx.grad
,我们可以允许DiffState
在argnums
参数中使用,以指定要微分参数的位置和指定模块的可微分状态的过滤器。
grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y)
Rng处理#
为了简化RNG状态处理,我们建议删除vmap
和scan
中单独的split_rngs
参数。相反,我们建议引入一个新的nnx.split_rngs
API,它将在转换之前和之后管理RNG处理。这种方法为用户提供了更明确的控制,并且与JAX转换行为更好地保持一致。
一致的别名#
为了确保遵循引用语义的对象的转换的正确性,我们必须对所有引用的别名强制执行一致的提升/降低规范。转换必须遵守两个规则
引用的所有别名都必须接收**完全相同**的提升/降低规范。
在转换函数的输出上不允许捕获引用。
例如
@nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes)
def f(m1, m2, m1_alias):
return m2
m2 = f(m1, m2, m1)
这里,m1
有两个输入别名,因为它作为第一个和第三个输入传递给f
,但这是可以接受的,因为m1_axes
在in_axes
中都分配给了它。m2
作为第二个输入传递并具有输出别名,这也是可以接受的,因为m2_axes
在in_axes
和out_axes
中都分配了。
让我们检查一些根据这些标准应被拒绝的程序示例
不一致的输入别名#
考虑一个具有两个参数m1
和m2
的函数,分别在轴0
和1
上向量化。将同一个模块作为这两个参数传递将是不一致的。
@nnx.vmap(in_axes=(0, 1))
def f(m1: Module, m2: Module):
...
f(m, m) # This should be rejected
不一致的输入/输出别名#
现在考虑在vmap
下具有in_axes=0
和out_axes=1
的恒等函数g
。在JAX中,这将导致转置输入中的数组。
@nnx.vmap(in_axes=0, out_axes=1)
def g(m: Module):
return m
虽然这看起来是正确的,但在NNX中这种行为没有明确定义,因为共享的可变引用充当辅助输出。在后台,g
被转换为一个函数,该函数将输入作为额外的第一个输出,并且out_axes
为该输出设置为与in_axes
相同的值。
@nnx.vmap(in_axes=0, out_axes=(0, 1))
def g_real(m: Module):
return m, m
此返回结构揭示了一个不一致之处:我们试图使用out_axes=0
和out_axes=1
都降低m
。
嵌套结构中不一致的别名#
类似的问题可能出现在不太明显的情况下,例如当m
包含在另一个结构中时。
@nnx.vmap(in_axes=0, out_axes=1)
def f(m: Module):
return SomeModule(m)
这意味着我们必须遍历输入和输出的整个图形以检查一致的赋值。当传递具有不同规范的共享引用输入/输出时,也会发生同样的问题。
shared = Shared()
m1, m2 = Foo(shared), Foo(shared)
@nnx.vmap(in_axes=(0, 1))
def f(m1, m2): # shared is passed through both
...
捕获的模块不能作为输出#
最后,让我们考虑第二个一致别名规则,该规则规定捕获的模块不能作为输出。这里的主要问题是NNX需要将所有输入引用一起拆分以跟踪更改,但捕获的模块绕过了此过程。将它们视为新引用会导致**隐式克隆**。
m = SomeModule()
@nnx.vmap(out_axes=0, axis_size=5)
def f():
return m
assert m is not f() # implicit cloning
为了保留引用标识,我们必须不允许捕获的模块作为输出。在实践中,我们可以使用用于限制来自不同级别的模块上的状态更新的跟踪级别上下文机制来检测捕获的模块。
总结#
在本文件中,我们
讨论了当前实现中的一些问题,这些问题使得 JAX 用户难以理解。
提议重构 NNX 变换,以允许用户在与对象交互时使用常规 JAX 语义,从而删除由 NNX 变换引入的额外参数。
引入了在 JAX API 中使用 Lift 类型来弥补 NNX 对象中缺乏“前缀”概念的不足,从而实现模块子状态的独立提升。
提议了一个新的
nnx.split_rngs
API 来替换vmap
和scan
中的split_rngs
参数,使 RNG 处理成为显式操作,并为用户提供更多控制权。分析了由共享可变引用别名导致的边缘情况,并提议对所有具有输入语义的变换实施**一致别名**。