性能考量

性能考量#

目前,Flax nnx.jit 在纯 Python 中遍历对象图,这很慢并且会增加开销。为了解决这个问题,Flax 团队将开发一个名为 flaxlib 的 Rust 扩展,以加速 graph.py 中的一些遍历逻辑。这类似于 JAX 团队通过引入 jaxlib 来解决标准 JAX pytrees 的类似问题(请参考 Flax PR #4196 中的第一步)。

但是,有两点需要考虑:

异步调度#

benchmarks/nnx_simple_training.py 中,我们正在增加层宽度(每层的特征数量)并测量使用 nnx.jitjax.jit 训练的同一模型的总训练时间。

如下图所示,在达到某个模型大小之后,遍历所花费的时间会被异步调度完全吸收。当 Python 能够完成当前的 for 循环步骤,并到达下一个 train_step,而 JAX 尚未完成上一个 train_step 时,就会发生这种情况。

performance-graph

这意味着您只需担心小型模型的 nnx.jit 开销。如果您正在使用小型模型,请查看下一节,了解如何消除开销。

降低 Python 开销#

要消除 Python 开销,您可以将常规的 jax.jitnnx.splitnnx.merge 结合使用,以分阶段执行遍历逻辑。

要学习如何执行此操作,我们首先创建以下简单的 Model

from flax import nnx
import jax
import jax.numpy as jnp
import optax

class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

接下来,让我们创建一个使用 nnx.jittrain_step() 函数,该函数接收 modeloptimizermetrics,它们都是 Flax NNX 对象。

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)

@nnx.jit  # <== currently slow
def train_step(model, optimizer, metrics, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates
  metrics.update(loss=loss)

  return loss
  
for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  loss = train_step(model, optimizer, metrics, x, y)

为了加速此过程,在开始训练循环之前,我们可以对 train_step() 的所有 Flax NNX 对象输入使用 nnx.split,以创建遍历速度更快的 graphdefstate pytree。

接下来,我们更改 train_step() 以接受 graphdefstate,并在 train_step() 的开头和结尾使用 nnx.mergennx.split,以便在对象及其 pytree 表示之间来回切换。即使 nnx.splitnnx.merge 很慢,也无关紧要,因为它们只会在跟踪期间运行一次。

完成此操作后,我们可以更改 train_step() 函数以使用 jax.jit 而不是 nnx.jit

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))  # reference sharing
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)
# split before training loop
graphdef, state = nnx.split((model, optimizer, metrics))

@jax.jit  # regular JAX
def train_step(graphdef, state, x, y):
  # merge at the beginning of the function
  model, optimizer, metrics = nnx.merge(graphdef, state)

  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  metrics.update(loss=loss)

  # split at the end of the function
  _, state = nnx.split((model, optimizer, metrics))

  # return new state
  return state, loss

for _ in range(10):
  x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
  state, loss = train_step(graphdef, state, x, y)

# update objects after training
nnx.update((model, optimizer, metrics), state)

请注意,我们仅对 jit 执行此操作。您仍然可以使用其他 Flax 转换,例如上面示例中显示的 nnx.value_and_grad,因为它们的开销已经被外部的 jit 吸收。

并且在训练循环完成后(或在需要时),我们可以使用 Flax nnx.update 将 Flax NNX 对象(如 modeloptimizermetrics)更新为新的 state