为何选择 Flax NNX?#

2020 年,Flax 团队发布了 Flax Linen API,以支持 JAX 上的建模研究,重点是扩展和性能。此后,我们从用户那里学到了很多。该团队引入了一些被证明对用户有益的想法,例如

Flax 团队所做的选择之一是使用函数式(compact)语义,通过参数的惰性初始化来进行神经网络编程。这使得实现代码简洁,并使 Flax Linen API 与 Haiku 对齐。

然而,这也意味着 Flax 中模块和变量的语义是非 Python 式的,并且常常令人惊讶。它还导致了实现复杂性,并模糊了神经网络上的转换 (transforms)的核心思想。

Flax NNX 简介#

快进到 2024 年,Flax 团队开发了 Flax NNX,旨在保留 Flax Linen 对用户有用的功能,同时引入一些新原则。Flax NNX 背后的核心思想是在 JAX 中引入引用语义。以下是其主要功能

  • NNX 是 Python 式的:模块的常规 Python 语义,包括对可变性和共享引用的支持。

  • NNX 很简单:Flax Linen 中的许多复杂 API 都使用 Python 习惯用法进行了简化或完全删除。

  • 更好的 JAX 集成:自定义 NNX 转换采用与 JAX 转换相同的 API。使用 NNX 可以更容易地直接使用 JAX 转换(高阶函数)

以下是一个简单的 Flax NNX 程序示例,说明了以上许多要点

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # Reference sharing.

@nnx.jit  # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss

Flax NNX 对 Linen 的改进#

本文档的其余部分使用各种示例来演示 Flax NNX 如何改进 Flax Linen。

检查#

第一个改进是 Flax NNX 模块是常规 Python 对象。这意味着您可以轻松地构造和检查 Module 对象。

另一方面,Flax Linen 模块不容易检查和调试,因为它们是惰性的,这意味着某些属性在构造时不可用,只能在运行时访问。

class Block(nn.Module):
  def setup(self):
    self.linear = nn.Dense(10)

block = Block()

try:
  block.linear  # AttributeError: "Block" object has no attribute "linear".
except AttributeError as e:
  pass





...
class Block(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)

block = Block(nnx.Rngs(0))


block.linear
# Linear(
#   kernel=Param(
#     value=Array(shape=(5, 10), dtype=float32)
#   ),
#   bias=Param(
#     value=Array(shape=(10,), dtype=float32)
#   ),
#   ...

请注意,在上面的 Flax NNX 示例中,没有形状推断 - 输入和输出形状都必须提供给 Linear nnx.Module。这是一个权衡,允许更明确和可预测的行为。

运行计算#

在 Flax Linen 中,所有顶层计算都必须通过 flax.linen.Module.initflax.linen.Module.apply 方法完成,并且参数或任何其他类型的状态都作为单独的结构处理。这在以下两者之间创建了不对称性:1) 在 apply 内部运行的代码,可以直接运行方法和其他 Module 对象;以及 2) 在 apply 外部运行的代码,必须使用 apply 方法。

在 Flax NNX 中,没有特殊的上下文,因为参数作为属性保存,并且可以直接调用方法。这意味着您的 NNX 模块的 __init____call__ 方法与其他类方法的处理方式不同,而 Flax Linen 模块的 setup()__call__ 方法是特殊的。

Encoder = lambda: nn.Dense(10)
Decoder = lambda: nn.Dense(2)

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

x = jnp.ones((1, 2))
model = AutoEncoder()
params = model.init(random.key(0), x)['params']

y = model.apply({'params': params}, x)
z = model.apply({'params': params}, x, method='encode')
y = Decoder().apply({'params': params['decoder']}, z)
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

x = jnp.ones((1, 2))
model = AutoEncoder(nnx.Rngs(0))


y = model(x)
z = model.encode(x)
y = model.decoder(z)

在 Flax Linen 中,直接调用子模块是不可能的,因为它们没有初始化。因此,您必须做的是构造一个新实例,然后提供一个正确的参数结构。

但是在 Flax NNX 中,您可以直接调用子模块,没有任何问题。

状态处理#

Flax Linen 臭名昭著的复杂领域之一是状态处理。当您使用 Dropout 层、BatchNorm 层或两者都使用时,您突然必须处理新状态,并使用它来配置 flax.linen.Module.apply 方法。

在 Flax NNX 中,状态保存在 nnx.Module 中,并且是可变的,这意味着它可以直接调用。

class Block(nn.Module):
  train: bool

  def setup(self):
    self.linear = nn.Dense(10)
    self.bn = nn.BatchNorm(use_running_average=not self.train)
    self.dropout = nn.Dropout(0.1, deterministic=not self.train)

  def __call__(self, x):
    return nn.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(train=True)
vs = model.init(random.key(0), x)
params, batch_stats = vs['params'], vs['batch_stats']

y, updates = model.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  rngs={'dropout': random.key(1)},
  mutable=['batch_stats'],
)
batch_stats = updates['batch_stats']
class Block(nnx.Module):


  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)
    self.bn = nnx.BatchNorm(10, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))



y = model(x)





...

Flax NNX 的状态处理的主要好处是,当您添加新的有状态层时,不必更改训练代码。

此外,在 Flax NNX 中,处理状态的层也很容易实现。下面是 BatchNorm 层的简化版本,该层在每次调用时都会更新均值和方差。

class BatchNorm(nnx.Module):
  def __init__(self, features: int, mu: float = 0.95):
    # Variables
    self.scale = nnx.Param(jax.numpy.ones((features,)))
    self.bias = nnx.Param(jax.numpy.zeros((features,)))
    self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
    self.var = nnx.BatchStat(jax.numpy.ones((features,)))
    self.mu = mu  # Static

def __call__(self, x):
  mean = jax.numpy.mean(x, axis=-1)
  var = jax.numpy.var(x, axis=-1)
  # ema updates
  self.mean.value = self.mu * self.mean + (1 - self.mu) * mean
  self.var.value = self.mu * self.var + (1 - self.mu) * var
  # normalize and scale
  x = (x - mean) / jax.numpy.sqrt(var + 1e-5)
  return x * self.scale + self.bias

模型手术#

在 Flax Linen 中,由于以下两个原因,模型手术在历史上一直具有挑战性

  1. 由于惰性初始化,不能保证可以用新模块替换子Module

  2. 参数结构与 flax.linen.Module 结构分离,这意味着您必须手动保持它们同步。

在 Flax NNX 中,您可以直接按照 Python 语义替换子模块。由于参数是 nnx.Module 结构的一部分,因此它们永远不会不同步。下面是如何实现 LoRA 层,然后用它来替换现有模型中的 Linear 层的示例。

class LoraLinear(nn.Module):
  linear: nn.Dense
  rank: int

  @nn.compact
  def __call__(self, x: jax.Array):
    A = self.param(random.normal, (x.shape[-1], self.rank))
    B = self.param(random.normal, (self.rank, self.linear.features))

    return self.linear(x) + x @ A @ B

try:
  model = Block(train=True)
  model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR

  lora_params = model.linear.init(random.key(1), x)
  lora_params['linear'] = params['linear']
  params['linear'] = lora_params

except AttributeError as e:
  pass
class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear, rank, rngs):
    self.linear = linear
    self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank)))
    self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = Block(rngs)
model.linear = LoraLinear(model.linear, rank=5, rngs=rngs)






...

如上所示,在 Flax Linen 中,这在这种情况下实际上不起作用,因为 linearModule 不可用。但是,其余代码提供了有关如何手动更新 params 结构的想法。

在 Flax Linen 中执行任意模型手术并不容易,目前 intercept_methods API 是执行方法通用修补的唯一方法。但是此 API 不是很好用。

在 Flax NNX 中,要进行通用模型手术,您可以只使用 nnx.iter_graph,这比在 Linen 中简单得多,也容易得多。下面是一个示例,说明如何将模型中的所有 nnx.Linear 层替换为自定义的 LoraLinear NNX 层。

rngs = nnx.Rngs(0)
model = Block(rngs)

for path, module in nnx.iter_graph(model):
  if isinstance(module, nnx.Module):
    for name, value in vars(module).items():
      if isinstance(value, nnx.Linear):
        setattr(module, name, LoraLinear(value, rank=5, rngs=rngs))

转换#

Flax Linen 转换非常强大,因为它们可以对模型的状态进行细粒度控制。但是,Flax Linen 转换也有缺点,例如

  1. 它们公开了不属于 JAX 的其他 API,使其行为令人困惑,有时与其 JAX 对应项不同。这也限制了您与 JAX 转换交互的方式,并跟上 JAX API 的更改。

  2. 它们适用于具有非常特定签名的函数,即

  • flax.linen.Module 必须是第一个参数。

  • 它们接受其他 Module 对象作为参数,但不作为返回值。

  1. 它们只能在 flax.linen.Module.apply 内部使用。

另一方面,Flax NNX 转换旨在与其相应的 JAX 转换等效,但有一个例外 - 它们可以用于 Flax NNX 模块。这意味着 Flax 转换

  1. 具有与 JAX 转换相同的 API。

  2. 可以在任何参数上接受 Flax NNX 模块,并且可以从中返回 nnx.Module 对象。

  3. 可以在任何地方使用,包括训练循环。

下面是一个使用 Flax NNX 的 vmap 的示例,通过转换 create_weights 函数来创建权重堆栈,该函数返回一些 Weights,并通过转换 vector_dot 函数将该权重堆栈单独应用于一批输入,该函数将 Weights 作为第一个参数,一批输入作为第二个参数。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds)

x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)

与 Flax Linen 转换相反,in_axes 参数和其他 API 确实会影响 nnx.Module 状态的转换方式。

此外,Flax NNX 转换可以用作方法装饰器,因为 nnx.Module 方法只是将 Module 作为第一个参数的函数。这意味着前面的示例可以重写如下

class WeightStack(nnx.Module):
  @nnx.vmap(in_axes=(0, 0), out_axes=0)
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=(0, 0), out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)