Flax NNX 与 JAX 转换的比较#
本指南描述了 Flax NNX 转换 和 JAX 转换 之间的差异,以及如何无缝切换或并排使用它们。此处的示例将重点关注 nnx.jit
、jax.jit
、nnx.grad
和 jax.grad
函数转换(transforms)。
首先,让我们设置导入并生成一些虚拟数据
from flax import nnx
import jax
x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))
差异#
Flax NNX 转换可以转换非纯函数,并进行更改和副作用: - Flax NNX 转换使您能够转换以 Flax NNX 图对象作为参数的函数 - 例如 nnx.Module
、nnx.Rngs
、nnx.Optimizer
等 - 即使是那些状态会被更改的对象。 - 相比之下,这些类型的对象在 JAX 转换中无法识别。
Flax NNX 函数式 API 提供了一种将图结构转换为 pytrees 并返回的方法。通过在每个函数边界执行此操作,您可以有效地将图结构与任何 JAX 转换一起使用,并以与函数式纯度一致的方式传播状态更新。
Flax NNX 自定义转换,例如 nnx.jit
和 nnx.grad
,只是删除了样板代码,因此代码看起来是有状态的。
以下是使用 nnx.jit
和 nnx.grad
转换与使用 jax.jit
和 jax.grad
转换的代码的比较示例。
请注意
Flax NNX 转换函数的函数签名可以直接接受
nnx.Linear
nnx.Module
实例,并对Module
进行有状态的更新。JAX 转换函数的函数签名只能接受 pytree 注册的
nnx.State
和nnx.GraphDef
对象,并且必须返回它们的更新副本以保持转换函数的纯度。
@nnx.jit
def train_step(model, x, y):
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
def loss_fn(graphdef, state):
model = nnx.merge(graphdef, state)
return ((model(x) - y) ** 2).mean()
grads = jax.grad(loss_fn, argnums=1)(graphdef, state)
model = nnx.merge(graphdef, state)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
return nnx.split(model)
graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)
混合使用 Flax NNX 和 JAX 转换#
Flax NNX 转换和 JAX 转换都可以混合在一起使用,只要您代码中的 JAX 转换函数是纯的,并且具有 JAX 可识别的有效参数类型。
@nnx.jit
def train_step(model, x, y):
def loss_fn(graphdef, state):
model = nnx.merge(graphdef, state)
return ((model(x) - y) ** 2).mean()
grads = jax.grad(loss_fn, 1)(*nnx.split(model))
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
@jax.jit
def train_step(graphdef, state, x, y):
model = nnx.merge(graphdef, state)
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
return nnx.split(model)
graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)