Flax 基础知识#
Flax NNX 是一个新的简化 API,旨在简化在 JAX 中创建、检查、调试和分析神经网络的过程。它通过为 Python 引用语义添加一流的支持来实现这一点。这允许用户使用常规 Python 对象来表达他们的模型,这些对象被建模为 PyGraphs(而不是 pytrees),从而实现引用共享和可变性。这种 API 设计应该会让 PyTorch 或 Keras 用户感到宾至如归。
首先,使用 pip
安装 Flax 并导入必要的依赖项
# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp
Flax NNX 模块系统#
Flax nnx.Module
与 Flax Linen 或 Haiku 中其他 Module
系统之间的主要区别在于,在 NNX 中,一切都是显式的。这意味着,除其他外,nnx.Module
本身直接持有状态(例如参数),PRNG 状态由用户线程化,并且所有形状信息必须在初始化时提供(没有形状推断)。
让我们首先创建一个 Linear
nnx.Module
。如下所示,动态状态通常存储在 nnx.Param
中,静态状态(所有 NNX 未处理的类型),例如整数或字符串,则直接存储。类型为 jax.Array
和 numpy.ndarray
的属性也被视为动态状态,尽管将它们存储在 nnx.Variable
(例如 Param
)中是首选的。此外,nnx.Rngs
对象可用于根据传递给构造函数的根 PRNG 密钥获取新的唯一密钥。
class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs.params()
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.din, self.dout = din, dout
def __call__(self, x: jax.Array):
return x @ self.w + self.b
另请注意,可以使用 value
属性访问 nnx.Variable
的内部值,但为方便起见,它们实现了所有数值运算符,并且可以直接在算术表达式中使用(如上面的代码所示)。
要初始化 Flax nnx.Module
,只需调用构造函数,并且通常会急切地创建 Module
的所有参数。由于 nnx.Module
持有自己的状态方法,您可以直接调用它们,而无需单独的 apply
方法。这对于调试非常方便,允许您直接检查模型的整个结构。
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))
print(y)
nnx.display(model)
[[1.245453 0.74195766 0.8553282 0.6763327 1.2617068 ]]
上面的 nnx.display
可视化是使用出色的 Treescope 库生成的。
有状态计算#
实现诸如 nnx.BatchNorm
之类的层需要在前向传递期间执行状态更新。在 Flax NNX 中,您只需创建一个 nnx.Variable
并在前向传递期间更新其 .value
。
class Count(nnx.Variable): pass
class Counter(nnx.Module):
def __init__(self):
self.count = Count(jnp.array(0))
def __call__(self):
self.count += 1
counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)
通常在 JAX 中避免使用可变引用。但是 Flax NNX 提供了完善的机制来处理它们,如本指南的后续部分所示。
嵌套模块#
Flax nnx.Module
可用于在嵌套结构中组合其他 Module
。这些可以直接作为属性分配,或者分配在任何(嵌套)pytree 类型的属性内部,例如 list
、dict
、tuple
等。
下面的示例展示了如何通过子类化 nnx.Module
来定义一个简单的 MLP
。该模型由两个 Linear
层、一个 nnx.Dropout
层和一个 nnx.BatchNorm
层组成。
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.linear2 = Linear(dmid, dout, rngs=rngs)
def __call__(self, x: jax.Array):
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
return self.linear2(x)
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
在 Flax 中,nnx.Dropout
是一个有状态的模块,它存储一个 nnx.Rngs
对象,以便它可以在前向传递期间生成新的掩码,而无需用户每次都传递新的密钥。
模型手术#
默认情况下,Flax nnx.Module
是可变的。这意味着它们的结构可以随时更改,这使得 模型手术 非常容易,因为任何子 Module
属性都可以替换为任何其他内容,例如新的 Module
、现有的共享 Module
、不同类型的 Module
等。此外,nnx.Variable
也可以修改或替换/共享。
以下示例展示了如何将上一个示例中 MLP
模型中的 Linear
层替换为 LoraLinear
层
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
self.linear = linear
self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)
# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
Flax 转换#
Flax NNX 转换 (transforms) 扩展了 JAX 转换以支持 nnx.Module
和其他对象。它们作为其等效 JAX 对手的超集,并增加了对对象状态的感知,并提供了其他 API 来转换它。
Flax 转换的主要功能之一是保留引用语义,这意味着只要在转换规则内合法,发生在转换内部的对象图的任何突变都会向外传播。在实践中,这意味着可以使用命令式代码来表达 Flax 程序,从而大大简化用户体验。
在下面的示例中,您定义一个 train_step
函数,该函数接受一个 MLP
模型、一个 nnx.Optimizer
和一批数据,并返回该步骤的损失。损失和梯度是使用 nnx.value_and_grad
转换在 loss_fn
上计算的。梯度被传递给优化器的 nnx.Optimizer.update
方法以更新 model
的参数。
import optax
# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
@nnx.jit # Automatic state management
def train_step(model, optimizer, x, y):
def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # In place updates.
return loss
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)
print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000255, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)
在此示例中发生了两件值得一提的事情
对
nnx.BatchNorm
和nnx.Dropout
层的状态的更新,会自动从loss_fn
内部传播到train_step
,一直到外部的model
引用。optimizer
持有对model
的可变引用 - 这种关系在train_step
函数内部被保留,使得仅使用优化器即可更新模型的参数。
注意
对于小型模型,nnx.jit
具有性能开销,请查看 性能注意事项 指南以获取更多信息。
扫描层#
下一个示例使用 Flax nnx.vmap
创建多个 MLP 层的堆栈,并使用 nnx.scan
将堆栈中的每一层迭代应用于输入。
在下面的代码中,请注意以下几点
自定义的
create_model
函数接收一个键并返回一个MLP
对象,因为您创建了五个键并使用nnx.vmap
对create_model
进行操作,因此创建了一个包含 5 个MLP
对象的堆栈。nnx.scan
用于将堆栈中的每个MLP
迭代应用于输入x
。nnx.scan
(有意识地)偏离了jax.lax.scan
,而是模仿了更具表现力的nnx.vmap
。nnx.scan
允许指定多个输入、每个输入/输出的扫描轴以及进位的位置。对于
nnx.BatchNorm
和nnx.Dropout
层,State
的更新由nnx.scan
自动传播。
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
x = model(x)
return x
x = jnp.ones((3, 10))
y = forward(model, x)
print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)
Flax NNX 转换如何实现这一点?为了理解 Flax NNX 对象如何与 JAX 转换交互,下一节将解释 Flax NNX 函数式 API。
Flax 函数式 API#
Flax NNX 函数式 API 在引用/对象语义和值/pytree 语义之间建立了一个清晰的边界。它还允许对状态进行与 Flax Linen 和 Haiku 用户习惯的相同程度的细粒度控制。Flax NNX 函数式 API 由三个基本方法组成:nnx.split
、nnx.merge
和 nnx.update
。
下面是一个 StatefulLinear
nnx.Module
的示例,它使用函数式 API。它包含
一些
nnx.Param
和nnx.Variable
;以及一个自定义的
Count()
nnx.Variable
类型,用于跟踪每次前向传递时增加的整数标量状态。
class Count(nnx.Variable): pass
class StatefulLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(jnp.array(0, dtype=jnp.uint32))
def __call__(self, x: jax.Array):
self.count += 1
return x @ self.w + self.b
model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))
nnx.display(model)
状态和 GraphDef#
可以使用 nnx.split
函数将 Flax nnx.Module
分解为 nnx.State
和 nnx.GraphDef
。
nnx.State
是一个从字符串到nnx.Variable
或嵌套State
的Mapping
。nnx.GraphDef
包含重建nnx.Module
图所需的所有静态信息,它类似于 JAX 的PyTreeDef
。
graphdef, state = nnx.split(model)
nnx.display(graphdef, state)
拆分、合并和更新#
Flax 的 nnx.merge
是 nnx.split
的反向操作。它接收 nnx.GraphDef
+ nnx.State
并重建 nnx.Module
。下面的示例演示了这一点,如下所示
nnx.update
可以使用给定nnx.State
的内容就地更新对象。此模式用于将状态从转换传播回外部的源对象。
print(f'{model.count.value = }')
# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)
@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
# 2. Use `nnx.merge` to create a new model inside the JAX transformation.
model = nnx.merge(graphdef, state)
# 3. Call the `nnx.Module`
y = model(x)
# 4. Use `nnx.split` to propagate `nnx.State` updates.
_, state = nnx.split(model)
return y, state
y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)
print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)
此模式的关键见解是在转换上下文中(包括基本 eager 解释器)使用可变引用是可行的,但在跨越边界时必须使用函数式 API。
为什么模块不只是 pytrees? 主要原因是很容易意外地丢失共享引用的踪迹,例如,如果您通过 JAX 边界传递两个具有共享 Module
的 nnx.Module
,您将静默丢失该共享。Flax 的函数式 API 使此行为显式化,因此更容易推理。
细粒度状态控制#
经验丰富的 Flax Linen 或 Haiku API 用户可能会意识到,将所有状态放在一个单一结构中并不总是最佳选择,因为在某些情况下,您可能希望以不同的方式处理状态的不同子集。这是与 JAX 转换交互时常见的现象。
例如
与
jax.grad
交互时,并非每个模型状态都可以或应该被微分。或者,有时,需要在使用
jax.lax.scan
时指定模型状态的哪一部分是进位,哪一部分不是。
为了解决这个问题,Flax NNX API 提供了 nnx.split
,它允许您传递一个或多个 nnx.filterlib.Filter
s 来将 nnx.Variable
s 分割成互斥的 nnx.State
s。Flax NNX 使用 Filter
在 API 中创建 State
组(例如 nnx.split
、nnx.state()
和许多 NNX 转换)。
下面的示例展示了最常见的 Filter
s
# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
nnx.display(params, counts)
注意: nnx.filterlib.Filter
s 必须是详尽的,如果一个值没有被匹配,则会引发错误。
正如预期的那样,nnx.merge
和 nnx.update
方法自然地会消耗多个 State
s
# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)