变换#
通常,JAX 变换(transforms)操作 pytrees 的 jax.Array
,并遵循值语义。这对于 Flax NNX 来说是一个挑战,因为它将 nnx.Module
表示为遵循引用语义的常规 Python 对象。为了解决这个问题,Flax NNX 引入了自己的一组变换,扩展了 JAX 变换,允许 nnx.Module
和其他 Flax NNX 对象在变换中传入和传出,同时保留引用语义。
如果您以前使用过 JAX 变换,那么 Flax NNX 变换应该会非常熟悉。它们使用相同的 API,并且在仅处理 jax.Array
的 pytrees 时,其行为类似于 JAX 变换。但是,当处理 Flax NNX 对象时,它们允许为这些对象保留 Python 的引用语义,这包括
保留变换输入和输出中多个对象之间的共享引用。
将变换内部对对象进行的任何状态更改传播到变换外部的对象。
当多个输入和输出之间存在别名时,强制对象如何变换的一致性。
import jax
from jax import numpy as jnp, random
from flax import nnx
在本指南中,nnx.vmap
被用作一个案例研究,以演示 Flax NNX 变换的工作原理。但是,本文档中概述的原则适用于所有变换。
基本示例#
首先,让我们看一个使用 nnx.vmap
将元素级的 vector_dot
函数扩展到处理批量输入的简单示例。我们将定义一个没有方法的 Weights
模块来保存一些参数,这些权重将作为输入传递给 vector_dot
函数以及一些数据。权重和数据都将在轴 0
上批量处理,我们将使用 nnx.vmap
将 vector_dot
应用于每个批次元素,结果将在轴 1
上批量处理。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))
def vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias
y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)
print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)
请注意,in_axes
与 Weights
模块自然地交互,将其视为 jax.Array
的 pytree。也允许使用前缀模式,因此在这种情况下,in_axes=(0, 0)
也会起作用。
对象也允许作为 Flax NNX 变换的输出,这对于变换初始化器很有用。例如,您可以定义一个 create_weights
函数来创建一个单一的 Weights
nnx.Module
,并使用 nnx.vmap
创建一个与之前形状相同的 Weights
堆栈。
def create_weights(seed: jax.Array):
return Weights(
kernel=random.uniform(random.key(seed), (2, 3)),
bias=jnp.zeros((3,)),
)
seeds = jnp.arange(10)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)
变换方法#
Python 中的方法只是将实例作为第一个参数的函数,这意味着您可以装饰 Module
和其他 Flax NNX 子类型的方法。例如,我们可以重构前面示例中的 Weights
,并使用 vmap
装饰 __init__
来完成 create_weights
的工作,并添加一个 __call__
方法并使用 @nnx.vmap
装饰它来完成 vector_dot
的工作。
class WeightStack(nnx.Module):
@nnx.vmap
def __init__(self, seed: jax.Array):
self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
self.bias = nnx.Param(jnp.zeros((3,)))
@nnx.vmap(in_axes=0, out_axes=1)
def __call__(self, x: jax.Array):
assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ self.kernel + self.bias
weights = WeightStack(jnp.arange(10))
x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)
print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)
本指南的其余部分将重点介绍变换单个函数。但请注意,所有示例都可以用这种方法风格编写。
状态传播#
到目前为止,我们的函数都是无状态的。但是,当您拥有有状态的函数时,Flax NNX 变换的真正威力就体现出来了,因为它们的主要功能之一是传播状态更改以保留引用语义。让我们通过向 Weights
添加一个 count
属性并在新的 stateful_vector_dot
函数中递增它来更新前面的示例。
class Count(nnx.Variable): pass
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))
def stateful_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
return x @ weights.kernel + weights.bias
y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)
weights.count
Count(
value=Array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int32)
)
运行一次 stateful_vector_dot
后,您验证了 count
属性已正确更新。由于 Weights
已向量化,count
被初始化为 arange(10)
,并且它的所有元素在变换内部都递增了 1
。最重要的是,更新被传播到变换外部的原始 Weights
对象。太棒了!
图更新传播#
JAX 变换将输入视为 jax.Array
的 pytrees,而 Flax NNX 将输入视为 jax.Array
的 pytrees 和 Python 引用,其中引用形成一个图。只要它们是输入本地的(不支持变换内部对全局变量的更新),Flax NNX 的状态传播机制就可以跟踪对对象的任意更新。
这意味着您可以根据需要修改图结构,包括更新现有属性、添加/删除属性、交换属性、在对象之间共享(新的)引用、在对象之间共享 nnx.Variable
等。一切皆有可能!
以下示例演示了在 nnx.vmap
内部对 Weights
对象执行一些任意更新,并验证这些更新是否正确传播到变换外部的原始 Weights
对象。
class Count(nnx.Variable): pass
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))
def crazy_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
y = x @ weights.kernel + weights.bias
weights.some_property = ['a', 2, False] # add attribute
del weights.bias # delete attribute
weights.new_param = weights.kernel # share reference
return y
y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)
nnx.display(weights)
能力越大,责任越大。
- 本叔
虽然此功能非常强大,但必须谨慎使用,因为它可能会与 JAX 对某些变换的底层假设发生冲突。例如,jit
希望输入的结构是稳定的,以便缓存编译后的函数,因此在 nnx.jit
-ed 函数内部更改图结构会导致持续的重新编译和性能下降。另一方面,scan
只允许固定的 carry
结构,因此添加/删除声明为 carry 的子状态会导致错误。
变换子状态(提升类型)#
某些 JAX 变换允许使用 pytree 前缀来指定应该如何变换输入/输出的不同部分。Flax NNX 支持 pytree 结构的 pytree 前缀,但目前它没有图对象的前缀概念。相反,Flax NNX 引入了“提升类型”的概念,该概念允许指定应该如何变换对象的不同子状态。不同的变换支持不同的提升类型,以下是每个 JAX 变换当前支持的 FLax NNX 提升类型列表
提升类型 |
JAX 变换 |
---|---|
|
|
|
|
|
|
注意:* 在编写本文档的此版本时,Flax NNX
shard_map
尚未实现。
为了指定如何在 nnx.vmap
中向量化对象的不同子状态,Flax 团队创建了一个 nnx.StateAxes
。StateAxes
通过 Flax NNX 过滤器将一组子状态映射到其相应的轴,你可以将 nnx.StateAxes
传递给 in_axes
和 out_axes
,就像它/它们是 pytree 前缀一样。
让我们使用之前的 stateful_vector_dot
示例,仅向量化 nnx.Param
变量,并广播 count
变量,这样我们只为所有批次元素保留一个计数。为此,我们将定义一个 nnx.StateAxes
,其中包含一个匹配 nnx.Param
变量的过滤器,并将它们映射到轴 0
,并将所有 Count
变量映射到 None
,并将此 nnx.StateAxes
传递给 in_axes
作为 Weights
对象。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
weights = Weights(
kernel=random.uniform(random.key(0), (10, 2, 3)),
bias=jnp.zeros((10, 3)),
count=jnp.array(0),
)
x = jax.random.normal(random.key(1), (10, 2))
def stateful_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
return x @ weights.kernel + weights.bias
state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)
weights.count
Count(
value=Array(1, dtype=int32, weak_type=True)
)
在这里,count
现在是一个标量,因为它没有被向量化。另请注意,nnx.StateAxes
只能直接用于 Flax NNX 对象,不能用作对象 pytree 的前缀。
随机状态#
在 Flax NNX 中,随机状态只是一个常规状态。这意味着它存储在需要它的 nnx.Module
中,并且被视为任何其他类型的状态。这是对 Flax Linen 的简化,在 Flax Linen 中,随机状态由单独的机制处理。实际上,nnx.Module
只需要保留对初始化期间传递给它们的 Rngs
对象的引用,并使用它为每个随机操作生成唯一的键。就本指南而言,这意味着随机状态可以像任何其他类型的状态一样进行转换,但我们也需要了解状态的布局方式,以便我们可以正确地转换它。
假设你想稍微改变一下,并将相同的权重应用于批次中的所有元素。但你也想为每个元素添加不同的随机噪声。
为此,你将向 Weights
添加一个 Rngs
属性,该属性从构造期间传递的 seed
键参数创建。此种子键必须事先被 split
,以便你可以成功地对其进行向量化。出于教学原因,你将种子键分配给一个 noise
“流”,并从中进行采样。要向量化 PRNG 状态,你必须配置 nnx.StateAxes
将所有 RngState
(Rngs
中所有变量的基类)映射到轴 0
,并将 nnx.Param
和 Count
映射到 None
。
class Weights(nnx.Module):
def __init__(self, kernel, bias, count, seed):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
self.count = Count(count)
self.rngs = nnx.Rngs(noise=seed)
weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))
def noisy_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
y = x @ weights.kernel + weights.bias
return y + random.normal(weights.rngs.noise(), y.shape)
state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
print(jnp.allclose(y1, y2))
nnx.display(weights)
False
由于 Rngs
的状态是就地更新的,并且由 nnx.vmap
自动传播,因此每次调用 noisy_vector_dot
时,我们都会得到不同的结果。
在上面的示例中,你是在构造期间手动拆分随机状态。这没问题,因为它清楚地表明了意图,但它也不允许你在 nnx.vmap
之外使用 Rngs
,因为它的状态始终是拆分的。为了解决这个问题,你可以传递一个未拆分的种子,并在 nnx.vmap
之前使用 nnx.split_rngs
装饰器,以便在每次调用函数之前拆分 RngState
,然后将其“降低”回来,使其变得可用。
weights = Weights(
kernel=random.uniform(random.key(0), (2, 3)),
bias=jnp.zeros((3,)),
count=jnp.array(0),
seed=0,
)
x = random.normal(random.key(1), (10, 2))
state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
weights.count += 1
y = x @ weights.kernel + weights.bias
return y + random.normal(weights.rngs.noise(), y.shape)
y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)
print(jnp.allclose(y1, y2))
nnx.display(weights)
False
规则和限制#
在本节中,我们将介绍在转换内部使用模块时适用的一些规则和限制。
可变模块不能通过闭包传递#
虽然 Python 允许将对象作为闭包传递给函数,但 Flax NNX 转换通常不支持这样做。原因是,由于模块是可变的,因此很容易将跟踪器捕获到在转换外部创建的模块中,这是 JAX 中的静默错误。为了避免这种情况,Flax NNX 会检查被修改的模块和变量是否作为参数传递给转换后的函数。
例如,如果我们有一个有状态的模块(例如 Counter
),每次调用它时都会递增一个计数器,并且我们尝试将其作为闭包传递给用 nnx.jit
修饰的函数,我们将泄漏跟踪器。但是,Flax NNX 会引发错误,以防止这种情况发生
class Counter(nnx.Module):
def __init__(self):
self.count = nnx.Param(jnp.array(0))
def increment(self):
self.count += jnp.array(1)
counter = Counter()
@nnx.jit
def f(x):
counter.increment()
return 2 * x
try:
y = f(3)
except Exception as e:
print(e)
要解决此问题,请将所有模块作为参数传递给正在转换的函数。在这种情况下,f
应该接受 counter
作为参数。
一致的别名#
在转换中允许引用语义的主要问题是,引用可以在输入和输出之间共享。如果不加以处理,这可能会导致不明确或不一致的行为。在下面的示例中,你有一个单独的 Weights
模块 m
,其引用在 arg1
和 arg2
中的多个位置出现。这里的问题是,你还指定要以轴 0
向量化 arg1
,并以轴 1
向量化 arg2
。这在 JAX 中没问题,因为 pytree 具有引用透明性。但这在 Flax NNX 中是有问题的,因为你尝试以两种不同的方式向量化 m
。Flax NNX 将通过引发错误来强制执行一致性。
class Weights(nnx.Module):
def __init__(self, array: jax.Array):
self.param = nnx.Param(array)
m = Weights(jnp.arange(10))
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]
@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
...
try:
f(arg1, arg2)
except ValueError as e:
print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
param: 0
param: 0
param: 1
输入和输出之间也可能发生不一致的别名。在下一个示例中,你有一个简单的函数,它接受并立即返回 arg1
。但是,arg1
在输入上以轴 0
向量化,在输出上以轴 1
向量化。正如预期的那样,这有问题,Flax NNX 会引发错误。
@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
return arg1
try:
f(arg1)
except ValueError as e:
print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
param: 0
param: 0
param: 1
轴元数据#
Flax NNX Variable
可以保存任意元数据,这些元数据可以通过将其作为关键字参数传递给其构造函数来添加。这通常用于存储 sharding
信息,如 nnx.spmd
API(如 nnx.get_partition_spec
和 nnx.get_named_sharding
)所使用。
但是,在涉及转换时,使此轴相关的信息与轴的实际状态保持同步通常很重要。例如,如果在轴 1
上向量化一个变量,则应在 vmap
或 scan
内部删除位置 1
的 sharding
信息,以反映轴被临时删除的事实。
为了实现这一点,Flax NNX 转换提供了一个非标准的 transform_metadata
字典参数。当存在 nnx.PARTITION_NAME
键时,sharding
元数据将根据 in_axes
和 out_axes
的指定进行更新。
让我们来看一个实际的例子
class Weights(nnx.Module):
def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
self.param = nnx.Param(array, sharding=sharding)
m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))
@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
print(f'Inner {m.param.shape = }')
print(f'Inner {m.param.sharding = }')
f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)
在这里,你向 nnx.Param
变量添加了 sharding
元数据,并使用 transform_metadata
来更新 sharding
元数据,以反映轴的变化。具体来说,你可以看到,当在 nnx.vmap
内部时,第一个轴 b
从 sharding
元数据中删除,然后在 nnx.vmap
外部时添加回来。
你可以验证,当在转换内部创建 nnx.Module
时,这也有效 - 新的 sharding
轴将被添加到转换外部的 nnx.Module
nnx.Variable
中,从而与转换后的 nnx.Variable
的轴匹配。
@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
return Weights(jnp.ones((3, 5)), sharding=('a', None))
m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)