变换#

通常,JAX 变换(transforms)操作 pytreesjax.Array,并遵循值语义。这对于 Flax NNX 来说是一个挑战,因为它将 nnx.Module 表示为遵循引用语义的常规 Python 对象。为了解决这个问题,Flax NNX 引入了自己的一组变换,扩展了 JAX 变换,允许 nnx.Module 和其他 Flax NNX 对象在变换中传入和传出,同时保留引用语义。

如果您以前使用过 JAX 变换,那么 Flax NNX 变换应该会非常熟悉。它们使用相同的 API,并且在仅处理 jax.Array 的 pytrees 时,其行为类似于 JAX 变换。但是,当处理 Flax NNX 对象时,它们允许为这些对象保留 Python 的引用语义,这包括

  • 保留变换输入和输出中多个对象之间的共享引用。

  • 将变换内部对对象进行的任何状态更改传播到变换外部的对象。

  • 当多个输入和输出之间存在别名时,强制对象如何变换的一致性。

import jax
from jax import numpy as jnp, random
from flax import nnx

在本指南中,nnx.vmap 被用作一个案例研究,以演示 Flax NNX 变换的工作原理。但是,本文档中概述的原则适用于所有变换。

基本示例#

首先,让我们看一个使用 nnx.vmap 将元素级的 vector_dot 函数扩展到处理批量输入的简单示例。我们将定义一个没有方法的 Weights 模块来保存一些参数,这些权重将作为输入传递给 vector_dot 函数以及一些数据。权重和数据都将在轴 0 上批量处理,我们将使用 nnx.vmapvector_dot 应用于每个批次元素,结果将在轴 1 上批量处理。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

请注意,in_axesWeights 模块自然地交互,将其视为 jax.Array 的 pytree。也允许使用前缀模式,因此在这种情况下,in_axes=(0, 0) 也会起作用。

对象也允许作为 Flax NNX 变换的输出,这对于变换初始化器很有用。例如,您可以定义一个 create_weights 函数来创建一个单一的 Weights nnx.Module,并使用 nnx.vmap 创建一个与之前形状相同的 Weights 堆栈。

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)

变换方法#

Python 中的方法只是将实例作为第一个参数的函数,这意味着您可以装饰 Module 和其他 Flax NNX 子类型的方法。例如,我们可以重构前面示例中的 Weights,并使用 vmap 装饰 __init__ 来完成 create_weights 的工作,并添加一个 __call__ 方法并使用 @nnx.vmap 装饰它来完成 vector_dot 的工作。

class WeightStack(nnx.Module):
  @nnx.vmap
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=0, out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

本指南的其余部分将重点介绍变换单个函数。但请注意,所有示例都可以用这种方法风格编写。

状态传播#

到目前为止,我们的函数都是无状态的。但是,当您拥有有状态的函数时,Flax NNX 变换的真正威力就体现出来了,因为它们的主要功能之一是传播状态更改以保留引用语义。让我们通过向 Weights 添加一个 count 属性并在新的 stateful_vector_dot 函数中递增它来更新前面的示例。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias


y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)

weights.count
Count(
  value=Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
)

运行一次 stateful_vector_dot 后,您验证了 count 属性已正确更新。由于 Weights 已向量化,count 被初始化为 arange(10),并且它的所有元素在变换内部都递增了 1。最重要的是,更新被传播到变换外部的原始 Weights 对象。太棒了!

图更新传播#

JAX 变换将输入视为 jax.Array 的 pytrees,而 Flax NNX 将输入视为 jax.Array 的 pytrees 和 Python 引用,其中引用形成一个图。只要它们是输入本地的(不支持变换内部对全局变量的更新),Flax NNX 的状态传播机制就可以跟踪对对象的任意更新。

这意味着您可以根据需要修改图结构,包括更新现有属性、添加/删除属性、交换属性、在对象之间共享(新的)引用、在对象之间共享 nnx.Variable 等。一切皆有可能!

以下示例演示了在 nnx.vmap 内部对 Weights 对象执行一些任意更新,并验证这些更新是否正确传播到变换外部的原始 Weights 对象。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  weights.some_property = ['a', 2, False] # add attribute
  del weights.bias # delete attribute
  weights.new_param = weights.kernel # share reference
  return y

y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)

nnx.display(weights)

能力越大,责任越大。
- 本叔

虽然此功能非常强大,但必须谨慎使用,因为它可能会与 JAX 对某些变换的底层假设发生冲突。例如,jit 希望输入的结构是稳定的,以便缓存编译后的函数,因此在 nnx.jit-ed 函数内部更改图结构会导致持续的重新编译和性能下降。另一方面,scan 只允许固定的 carry 结构,因此添加/删除声明为 carry 的子状态会导致错误。

变换子状态(提升类型)#

某些 JAX 变换允许使用 pytree 前缀来指定应该如何变换输入/输出的不同部分。Flax NNX 支持 pytree 结构的 pytree 前缀,但目前它没有图对象的前缀概念。相反,Flax NNX 引入了“提升类型”的概念,该概念允许指定应该如何变换对象的不同子状态。不同的变换支持不同的提升类型,以下是每个 JAX 变换当前支持的 FLax NNX 提升类型列表

提升类型

JAX 变换

StateAxes

vmappmapscan

StateSharding

jitshard_map*

DiffState

gradvalue_and_gradcustom_vjp

注意:* 在编写本文档的此版本时,Flax NNX shard_map 尚未实现。

为了指定如何在 nnx.vmap 中向量化对象的不同子状态,Flax 团队创建了一个 nnx.StateAxesStateAxes 通过 Flax NNX 过滤器将一组子状态映射到其相应的轴,你可以将 nnx.StateAxes 传递给 in_axesout_axes,就像它/它们是 pytree 前缀一样。

让我们使用之前的 stateful_vector_dot 示例,仅向量化 nnx.Param 变量,并广播 count 变量,这样我们只为所有批次元素保留一个计数。为此,我们将定义一个 nnx.StateAxes,其中包含一个匹配 nnx.Param 变量的过滤器,并将它们映射到轴 0,并将所有 Count 变量映射到 None,并将此 nnx.StateAxes 传递给 in_axes 作为 Weights 对象。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.array(0),
)
x = jax.random.normal(random.key(1), (10, 2))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

weights.count
Count(
  value=Array(1, dtype=int32, weak_type=True)
)

在这里,count 现在是一个标量,因为它没有被向量化。另请注意,nnx.StateAxes 只能直接用于 Flax NNX 对象,不能用作对象 pytree 的前缀。

随机状态#

在 Flax NNX 中,随机状态只是一个常规状态。这意味着它存储在需要它的 nnx.Module 中,并且被视为任何其他类型的状态。这是对 Flax Linen 的简化,在 Flax Linen 中,随机状态由单独的机制处理。实际上,nnx.Module 只需要保留对初始化期间传递给它们的 Rngs 对象的引用,并使用它为每个随机操作生成唯一的键。就本指南而言,这意味着随机状态可以像任何其他类型的状态一样进行转换,但我们也需要了解状态的布局方式,以便我们可以正确地转换它。

假设你想稍微改变一下,并将相同的权重应用于批次中的所有元素。但你也想为每个元素添加不同的随机噪声。

为此,你将向 Weights 添加一个 Rngs 属性,该属性从构造期间传递的 seed 键参数创建。此种子键必须事先被 split,以便你可以成功地对其进行向量化。出于教学原因,你将种子键分配给一个 noise “流”,并从中进行采样。要向量化 PRNG 状态,你必须配置 nnx.StateAxes 将所有 RngStateRngs 中所有变量的基类)映射到轴 0,并将 nnx.ParamCount 映射到 None

class Weights(nnx.Module):
  def __init__(self, kernel, bias, count, seed):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)
    self.rngs = nnx.Rngs(noise=seed)

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))

def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

由于 Rngs 的状态是就地更新的,并且由 nnx.vmap 自动传播,因此每次调用 noisy_vector_dot 时,我们都会得到不同的结果。

在上面的示例中,你是在构造期间手动拆分随机状态。这没问题,因为它清楚地表明了意图,但它也不允许你在 nnx.vmap 之外使用 Rngs,因为它的状态始终是拆分的。为了解决这个问题,你可以传递一个未拆分的种子,并在 nnx.vmap 之前使用 nnx.split_rngs 装饰器,以便在每次调用函数之前拆分 RngState,然后将其“降低”回来,使其变得可用。

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=0,
)
x = random.normal(random.key(1), (10, 2))

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})

@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

规则和限制#

在本节中,我们将介绍在转换内部使用模块时适用的一些规则和限制。

可变模块不能通过闭包传递#

虽然 Python 允许将对象作为闭包传递给函数,但 Flax NNX 转换通常不支持这样做。原因是,由于模块是可变的,因此很容易将跟踪器捕获到在转换外部创建的模块中,这是 JAX 中的静默错误。为了避免这种情况,Flax NNX 会检查被修改的模块和变量是否作为参数传递给转换后的函数。

例如,如果我们有一个有状态的模块(例如 Counter),每次调用它时都会递增一个计数器,并且我们尝试将其作为闭包传递给用 nnx.jit 修饰的函数,我们将泄漏跟踪器。但是,Flax NNX 会引发错误,以防止这种情况发生

class Counter(nnx.Module):
  def __init__(self):
    self.count = nnx.Param(jnp.array(0))

  def increment(self):
    self.count += jnp.array(1)

counter = Counter()

@nnx.jit
def f(x):
  counter.increment()
  return 2 * x

try:
  y = f(3)
except Exception as e:
  print(e)

要解决此问题,请将所有模块作为参数传递给正在转换的函数。在这种情况下,f 应该接受 counter 作为参数。

一致的别名#

在转换中允许引用语义的主要问题是,引用可以在输入和输出之间共享。如果不加以处理,这可能会导致不明确或不一致的行为。在下面的示例中,你有一个单独的 Weights 模块 m,其引用在 arg1arg2 中的多个位置出现。这里的问题是,你还指定要以轴 0 向量化 arg1,并以轴 1 向量化 arg2。这在 JAX 中没问题,因为 pytree 具有引用透明性。但这在 Flax NNX 中是有问题的,因为你尝试以两种不同的方式向量化 m。Flax NNX 将通过引发错误来强制执行一致性。

class Weights(nnx.Module):
  def __init__(self, array: jax.Array):
    self.param = nnx.Param(array)

m = Weights(jnp.arange(10))
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]

@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
  ...

try:
  f(arg1, arg2)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
  param: 0
  param: 0
  param: 1

输入和输出之间也可能发生不一致的别名。在下一个示例中,你有一个简单的函数,它接受并立即返回 arg1。但是,arg1 在输入上以轴 0 向量化,在输出上以轴 1 向量化。正如预期的那样,这有问题,Flax NNX 会引发错误。

@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
  return arg1

try:
  f(arg1)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variablelib.Param'>
  param: 0
  param: 0
  param: 1

轴元数据#

Flax NNX Variable 可以保存任意元数据,这些元数据可以通过将其作为关键字参数传递给其构造函数来添加。这通常用于存储 sharding 信息,如 nnx.spmd API(如 nnx.get_partition_specnnx.get_named_sharding)所使用。

但是,在涉及转换时,使此轴相关的信息与轴的实际状态保持同步通常很重要。例如,如果在轴 1 上向量化一个变量,则应在 vmapscan 内部删除位置 1sharding 信息,以反映轴被临时删除的事实。

为了实现这一点,Flax NNX 转换提供了一个非标准的 transform_metadata 字典参数。当存在 nnx.PARTITION_NAME 键时,sharding 元数据将根据 in_axesout_axes 的指定进行更新。

让我们来看一个实际的例子

class Weights(nnx.Module):
  def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
    self.param = nnx.Param(array, sharding=sharding)

m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))

@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
  print(f'Inner {m.param.shape = }')
  print(f'Inner {m.param.sharding = }')

f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)

在这里,你向 nnx.Param 变量添加了 sharding 元数据,并使用 transform_metadata 来更新 sharding 元数据,以反映轴的变化。具体来说,你可以看到,当在 nnx.vmap 内部时,第一个轴 bsharding 元数据中删除,然后在 nnx.vmap 外部时添加回来。

你可以验证,当在转换内部创建 nnx.Module 时,这也有效 - 新的 sharding 轴将被添加到转换外部的 nnx.Module nnx.Variable 中,从而与转换后的 nnx.Variable 的轴匹配。

@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
  return Weights(jnp.ones((3, 5)), sharding=('a', None))

m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)