Flax 基础知识#

Flax NNX 是一个新的简化 API,旨在简化在 JAX 中创建、检查、调试和分析神经网络的过程。它通过为 Python 引用语义添加一流的支持来实现这一点。这允许用户使用常规 Python 对象来表达他们的模型,这些对象被建模为 PyGraphs(而不是 pytrees),从而实现引用共享和可变性。这种 API 设计应该会让 PyTorch 或 Keras 用户感到宾至如归。

首先,使用 pip 安装 Flax 并导入必要的依赖项

# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp

Flax NNX 模块系统#

Flax nnx.ModuleFlax LinenHaiku 中其他 Module 系统之间的主要区别在于,在 NNX 中,一切都是显式的。这意味着,除其他外,nnx.Module 本身直接持有状态(例如参数),PRNG 状态由用户线程化,并且所有形状信息必须在初始化时提供(没有形状推断)。

让我们首先创建一个 Linear nnx.Module。如下所示,动态状态通常存储在 nnx.Param 中,静态状态(所有 NNX 未处理的类型),例如整数或字符串,则直接存储。类型为 jax.Arraynumpy.ndarray 的属性也被视为动态状态,尽管将它们存储在 nnx.Variable(例如 Param)中是首选的。此外,nnx.Rngs 对象可用于根据传递给构造函数的根 PRNG 密钥获取新的唯一密钥。

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

另请注意,可以使用 value 属性访问 nnx.Variable 的内部值,但为方便起见,它们实现了所有数值运算符,并且可以直接在算术表达式中使用(如上面的代码所示)。

要初始化 Flax nnx.Module,只需调用构造函数,并且通常会急切地创建 Module 的所有参数。由于 nnx.Module 持有自己的状态方法,您可以直接调用它们,而无需单独的 apply 方法。这对于调试非常方便,允许您直接检查模型的整个结构。

model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))

print(y)
nnx.display(model)
[[1.245453   0.74195766 0.8553282  0.6763327  1.2617068 ]]

上面的 nnx.display 可视化是使用出色的 Treescope 库生成的。

有状态计算#

实现诸如 nnx.BatchNorm 之类的层需要在前向传递期间执行状态更新。在 Flax NNX 中,您只需创建一个 nnx.Variable 并在前向传递期间更新其 .value

class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count += 1

counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)

通常在 JAX 中避免使用可变引用。但是 Flax NNX 提供了完善的机制来处理它们,如本指南的后续部分所示。

嵌套模块#

Flax nnx.Module 可用于在嵌套结构中组合其他 Module。这些可以直接作为属性分配,或者分配在任何(嵌套)pytree 类型的属性内部,例如 listdicttuple 等。

下面的示例展示了如何通过子类化 nnx.Module 来定义一个简单的 MLP。该模型由两个 Linear 层、一个 nnx.Dropout 层和一个 nnx.BatchNorm 层组成。

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

在 Flax 中,nnx.Dropout 是一个有状态的模块,它存储一个 nnx.Rngs 对象,以便它可以在前向传递期间生成新的掩码,而无需用户每次都传递新的密钥。

模型手术#

默认情况下,Flax nnx.Module 是可变的。这意味着它们的结构可以随时更改,这使得 模型手术 非常容易,因为任何子 Module 属性都可以替换为任何其他内容,例如新的 Module、现有的共享 Module、不同类型的 Module 等。此外,nnx.Variable 也可以修改或替换/共享。

以下示例展示了如何将上一个示例中 MLP 模型中的 Linear 层替换为 LoraLinear

class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
    self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)

# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

Flax 转换#

Flax NNX 转换 (transforms) 扩展了 JAX 转换以支持 nnx.Module 和其他对象。它们作为其等效 JAX 对手的超集,并增加了对对象状态的感知,并提供了其他 API 来转换它。

Flax 转换的主要功能之一是保留引用语义,这意味着只要在转换规则内合法,发生在转换内部的对象图的任何突变都会向外传播。在实践中,这意味着可以使用命令式代码来表达 Flax 程序,从而大大简化用户体验。

在下面的示例中,您定义一个 train_step 函数,该函数接受一个 MLP 模型、一个 nnx.Optimizer 和一批数据,并返回该步骤的损失。损失和梯度是使用 nnx.value_and_grad 转换在 loss_fn 上计算的。梯度被传递给优化器的 nnx.Optimizer.update 方法以更新 model 的参数。

import optax

# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # In place updates.

  return loss

x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000255, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)

在此示例中发生了两件值得一提的事情

  1. nnx.BatchNormnnx.Dropout 层的状态的更新,会自动从 loss_fn 内部传播到 train_step,一直到外部的 model 引用。

  2. optimizer 持有对 model 的可变引用 - 这种关系在 train_step 函数内部被保留,使得仅使用优化器即可更新模型的参数。

注意
对于小型模型,nnx.jit 具有性能开销,请查看 性能注意事项 指南以获取更多信息。

扫描层#

下一个示例使用 Flax nnx.vmap 创建多个 MLP 层的堆栈,并使用 nnx.scan 将堆栈中的每一层迭代应用于输入。

在下面的代码中,请注意以下几点

  1. 自定义的 create_model 函数接收一个键并返回一个 MLP 对象,因为您创建了五个键并使用 nnx.vmapcreate_model 进行操作,因此创建了一个包含 5 个 MLP 对象的堆栈。

  2. nnx.scan 用于将堆栈中的每个 MLP 迭代应用于输入 x

  3. nnx.scan (有意识地)偏离了 jax.lax.scan,而是模仿了更具表现力的 nnx.vmapnnx.scan 允许指定多个输入、每个输入/输出的扫描轴以及进位的位置。

  4. 对于 nnx.BatchNormnnx.Dropout 层,State 的更新由 nnx.scan 自动传播。

@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  x = model(x)
  return x

x = jnp.ones((3, 10))
y = forward(model, x)

print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)

Flax NNX 转换如何实现这一点?为了理解 Flax NNX 对象如何与 JAX 转换交互,下一节将解释 Flax NNX 函数式 API。

Flax 函数式 API#

Flax NNX 函数式 API 在引用/对象语义和值/pytree 语义之间建立了一个清晰的边界。它还允许对状态进行与 Flax Linen 和 Haiku 用户习惯的相同程度的细粒度控制。Flax NNX 函数式 API 由三个基本方法组成:nnx.splitnnx.mergennx.update

下面是一个 StatefulLinear nnx.Module 的示例,它使用函数式 API。它包含

class Count(nnx.Variable): pass

class StatefulLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array):
    self.count += 1
    return x @ self.w + self.b

model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))

nnx.display(model)

状态和 GraphDef#

可以使用 nnx.split 函数将 Flax nnx.Module 分解为 nnx.Statennx.GraphDef

graphdef, state = nnx.split(model)

nnx.display(graphdef, state)

拆分、合并和更新#

Flax 的 nnx.mergennx.split 的反向操作。它接收 nnx.GraphDef + nnx.State 并重建 nnx.Module。下面的示例演示了这一点,如下所示

  • 通过依次使用 nnx.splitnnx.merge,可以将任何 Module 提升到在任何 JAX 转换中使用。

  • nnx.update 可以使用给定 nnx.State 的内容就地更新对象。

  • 此模式用于将状态从转换传播回外部的源对象。

print(f'{model.count.value = }')

# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
  # 2. Use `nnx.merge` to create a new model inside the JAX transformation.
  model = nnx.merge(graphdef, state)
  # 3. Call the `nnx.Module`
  y = model(x)
  # 4. Use `nnx.split` to propagate `nnx.State` updates.
  _, state = nnx.split(model)
  return y, state

y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)

print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)

此模式的关键见解是在转换上下文中(包括基本 eager 解释器)使用可变引用是可行的,但在跨越边界时必须使用函数式 API。

为什么模块不只是 pytrees? 主要原因是很容易意外地丢失共享引用的踪迹,例如,如果您通过 JAX 边界传递两个具有共享 Modulennx.Module,您将静默丢失该共享。Flax 的函数式 API 使此行为显式化,因此更容易推理。

细粒度状态控制#

经验丰富的 Flax LinenHaiku API 用户可能会意识到,将所有状态放在一个单一结构中并不总是最佳选择,因为在某些情况下,您可能希望以不同的方式处理状态的不同子集。这是与 JAX 转换交互时常见的现象。

例如

  • jax.grad 交互时,并非每个模型状态都可以或应该被微分。

  • 或者,有时,需要在使用 jax.lax.scan 时指定模型状态的哪一部分是进位,哪一部分不是。

为了解决这个问题,Flax NNX API 提供了 nnx.split,它允许您传递一个或多个 nnx.filterlib.Filters 来将 nnx.Variables 分割成互斥的 nnx.States。Flax NNX 使用 Filter 在 API 中创建 State 组(例如 nnx.splitnnx.state() 和许多 NNX 转换)。

下面的示例展示了最常见的 Filters

# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)

nnx.display(params, counts)

注意: nnx.filterlib.Filters 必须是详尽的,如果一个值没有被匹配,则会引发错误。

正如预期的那样,nnx.mergennx.update 方法自然地会消耗多个 States

# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)