为何选择 Flax NNX?#
2020 年,Flax 团队发布了 Flax Linen API,以支持 JAX 上的建模研究,重点是扩展和性能。此后,我们从用户那里学到了很多。该团队引入了一些被证明对用户有益的想法,例如
将变量组织成集合。
自动且高效的伪随机数生成器 (PRNG) 管理。
变量元数据,用于 单程序多数据 (SPMD) 注释、优化器元数据和其他用例。
Flax 团队所做的选择之一是使用函数式(compact
)语义,通过参数的惰性初始化来进行神经网络编程。这使得实现代码简洁,并使 Flax Linen API 与 Haiku 对齐。
然而,这也意味着 Flax 中模块和变量的语义是非 Python 式的,并且常常令人惊讶。它还导致了实现复杂性,并模糊了神经网络上的转换 (transforms)的核心思想。
Flax NNX 简介#
快进到 2024 年,Flax 团队开发了 Flax NNX,旨在保留 Flax Linen 对用户有用的功能,同时引入一些新原则。Flax NNX 背后的核心思想是在 JAX 中引入引用语义。以下是其主要功能
NNX 是 Python 式的:模块的常规 Python 语义,包括对可变性和共享引用的支持。
NNX 很简单:Flax Linen 中的许多复杂 API 都使用 Python 习惯用法进行了简化或完全删除。
更好的 JAX 集成:自定义 NNX 转换采用与 JAX 转换相同的 API。使用 NNX 可以更容易地直接使用 JAX 转换(高阶函数)。
以下是一个简单的 Flax NNX 程序示例,说明了以上许多要点
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # Reference sharing.
@nnx.jit # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # in-place updates
return loss
Flax NNX 对 Linen 的改进#
本文档的其余部分使用各种示例来演示 Flax NNX 如何改进 Flax Linen。
检查#
第一个改进是 Flax NNX 模块是常规 Python 对象。这意味着您可以轻松地构造和检查 Module
对象。
另一方面,Flax Linen 模块不容易检查和调试,因为它们是惰性的,这意味着某些属性在构造时不可用,只能在运行时访问。
class Block(nn.Module):
def setup(self):
self.linear = nn.Dense(10)
block = Block()
try:
block.linear # AttributeError: "Block" object has no attribute "linear".
except AttributeError as e:
pass
...
class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(5, 10, rngs=rngs)
block = Block(nnx.Rngs(0))
block.linear
# Linear(
# kernel=Param(
# value=Array(shape=(5, 10), dtype=float32)
# ),
# bias=Param(
# value=Array(shape=(10,), dtype=float32)
# ),
# ...
请注意,在上面的 Flax NNX 示例中,没有形状推断 - 输入和输出形状都必须提供给 Linear
nnx.Module
。这是一个权衡,允许更明确和可预测的行为。
运行计算#
在 Flax Linen 中,所有顶层计算都必须通过 flax.linen.Module.init
或 flax.linen.Module.apply
方法完成,并且参数或任何其他类型的状态都作为单独的结构处理。这在以下两者之间创建了不对称性:1) 在 apply
内部运行的代码,可以直接运行方法和其他 Module
对象;以及 2) 在 apply
外部运行的代码,必须使用 apply
方法。
在 Flax NNX 中,没有特殊的上下文,因为参数作为属性保存,并且可以直接调用方法。这意味着您的 NNX 模块的 __init__
和 __call__
方法与其他类方法的处理方式不同,而 Flax Linen 模块的 setup()
和 __call__
方法是特殊的。
Encoder = lambda: nn.Dense(10)
Decoder = lambda: nn.Dense(2)
class AutoEncoder(nn.Module):
def setup(self):
self.encoder = Encoder()
self.decoder = Decoder()
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
x = jnp.ones((1, 2))
model = AutoEncoder()
params = model.init(random.key(0), x)['params']
y = model.apply({'params': params}, x)
z = model.apply({'params': params}, x, method='encode')
y = Decoder().apply({'params': params['decoder']}, z)
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)
class AutoEncoder(nnx.Module):
def __init__(self, rngs):
self.encoder = Encoder(rngs)
self.decoder = Decoder(rngs)
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
x = jnp.ones((1, 2))
model = AutoEncoder(nnx.Rngs(0))
y = model(x)
z = model.encode(x)
y = model.decoder(z)
在 Flax Linen 中,直接调用子模块是不可能的,因为它们没有初始化。因此,您必须做的是构造一个新实例,然后提供一个正确的参数结构。
但是在 Flax NNX 中,您可以直接调用子模块,没有任何问题。
状态处理#
Flax Linen 臭名昭著的复杂领域之一是状态处理。当您使用 Dropout 层、BatchNorm 层或两者都使用时,您突然必须处理新状态,并使用它来配置 flax.linen.Module.apply
方法。
在 Flax NNX 中,状态保存在 nnx.Module
中,并且是可变的,这意味着它可以直接调用。
class Block(nn.Module):
train: bool
def setup(self):
self.linear = nn.Dense(10)
self.bn = nn.BatchNorm(use_running_average=not self.train)
self.dropout = nn.Dropout(0.1, deterministic=not self.train)
def __call__(self, x):
return nn.relu(self.dropout(self.bn(self.linear(x))))
x = jnp.ones((1, 5))
model = Block(train=True)
vs = model.init(random.key(0), x)
params, batch_stats = vs['params'], vs['batch_stats']
y, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
x,
rngs={'dropout': random.key(1)},
mutable=['batch_stats'],
)
batch_stats = updates['batch_stats']
class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(5, 10, rngs=rngs)
self.bn = nnx.BatchNorm(10, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.dropout(self.bn(self.linear(x))))
x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))
y = model(x)
...
Flax NNX 的状态处理的主要好处是,当您添加新的有状态层时,不必更改训练代码。
此外,在 Flax NNX 中,处理状态的层也很容易实现。下面是 BatchNorm
层的简化版本,该层在每次调用时都会更新均值和方差。
class BatchNorm(nnx.Module):
def __init__(self, features: int, mu: float = 0.95):
# Variables
self.scale = nnx.Param(jax.numpy.ones((features,)))
self.bias = nnx.Param(jax.numpy.zeros((features,)))
self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
self.var = nnx.BatchStat(jax.numpy.ones((features,)))
self.mu = mu # Static
def __call__(self, x):
mean = jax.numpy.mean(x, axis=-1)
var = jax.numpy.var(x, axis=-1)
# ema updates
self.mean.value = self.mu * self.mean + (1 - self.mu) * mean
self.var.value = self.mu * self.var + (1 - self.mu) * var
# normalize and scale
x = (x - mean) / jax.numpy.sqrt(var + 1e-5)
return x * self.scale + self.bias
模型手术#
在 Flax Linen 中,由于以下两个原因,模型手术在历史上一直具有挑战性
由于惰性初始化,不能保证可以用新模块替换子
Module
。参数结构与
flax.linen.Module
结构分离,这意味着您必须手动保持它们同步。
在 Flax NNX 中,您可以直接按照 Python 语义替换子模块。由于参数是 nnx.Module
结构的一部分,因此它们永远不会不同步。下面是如何实现 LoRA 层,然后用它来替换现有模型中的 Linear
层的示例。
class LoraLinear(nn.Module):
linear: nn.Dense
rank: int
@nn.compact
def __call__(self, x: jax.Array):
A = self.param(random.normal, (x.shape[-1], self.rank))
B = self.param(random.normal, (self.rank, self.linear.features))
return self.linear(x) + x @ A @ B
try:
model = Block(train=True)
model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR
lora_params = model.linear.init(random.key(1), x)
lora_params['linear'] = params['linear']
params['linear'] = lora_params
except AttributeError as e:
pass
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear, rank, rngs):
self.linear = linear
self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank)))
self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = Block(rngs)
model.linear = LoraLinear(model.linear, rank=5, rngs=rngs)
...
如上所示,在 Flax Linen 中,这在这种情况下实际上不起作用,因为 linear
子Module
不可用。但是,其余代码提供了有关如何手动更新 params
结构的想法。
在 Flax Linen 中执行任意模型手术并不容易,目前 intercept_methods API 是执行方法通用修补的唯一方法。但是此 API 不是很好用。
在 Flax NNX 中,要进行通用模型手术,您可以只使用 nnx.iter_graph
,这比在 Linen 中简单得多,也容易得多。下面是一个示例,说明如何将模型中的所有 nnx.Linear
层替换为自定义的 LoraLinear
NNX 层。
rngs = nnx.Rngs(0)
model = Block(rngs)
for path, module in nnx.iter_graph(model):
if isinstance(module, nnx.Module):
for name, value in vars(module).items():
if isinstance(value, nnx.Linear):
setattr(module, name, LoraLinear(value, rank=5, rngs=rngs))
转换#
Flax Linen 转换非常强大,因为它们可以对模型的状态进行细粒度控制。但是,Flax Linen 转换也有缺点,例如
它们公开了不属于 JAX 的其他 API,使其行为令人困惑,有时与其 JAX 对应项不同。这也限制了您与 JAX 转换交互的方式,并跟上 JAX API 的更改。
它们适用于具有非常特定签名的函数,即
flax.linen.Module
必须是第一个参数。它们接受其他
Module
对象作为参数,但不作为返回值。
它们只能在
flax.linen.Module.apply
内部使用。
另一方面,Flax NNX 转换旨在与其相应的 JAX 转换等效,但有一个例外 - 它们可以用于 Flax NNX 模块。这意味着 Flax 转换
具有与 JAX 转换相同的 API。
可以在任何参数上接受 Flax NNX 模块,并且可以从中返回
nnx.Module
对象。可以在任何地方使用,包括训练循环。
下面是一个使用 Flax NNX 的 vmap
的示例,通过转换 create_weights
函数来创建权重堆栈,该函数返回一些 Weights
,并通过转换 vector_dot
函数将该权重堆栈单独应用于一批输入,该函数将 Weights
作为第一个参数,一批输入作为第二个参数。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
def create_weights(seed: jax.Array):
return Weights(
kernel=random.uniform(random.key(seed), (2, 3)),
bias=jnp.zeros((3,)),
)
def vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias
seeds = jnp.arange(10)
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds)
x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)
与 Flax Linen 转换相反,in_axes
参数和其他 API 确实会影响 nnx.Module
状态的转换方式。
此外,Flax NNX 转换可以用作方法装饰器,因为 nnx.Module
方法只是将 Module
作为第一个参数的函数。这意味着前面的示例可以重写如下
class WeightStack(nnx.Module):
@nnx.vmap(in_axes=(0, 0), out_axes=0)
def __init__(self, seed: jax.Array):
self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
self.bias = nnx.Param(jnp.zeros((3,)))
@nnx.vmap(in_axes=(0, 0), out_axes=1)
def __call__(self, x: jax.Array):
assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ self.kernel + self.bias
weights = WeightStack(jnp.arange(10))
x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)