在多个设备上扩展#

本指南演示了如何使用 JAX 即时编译机制 (jax.jit)flax.nnx.spmd 在 [多个设备和主机](Multi-host and multi-process environments)(例如 GPU、Google TPU 和 CPU)上扩展 Flax NNX Module

概述#

Flax 依赖于 JAX 进行数值计算,并在多个设备(如 GPU 和 Google TPU)上扩展计算。扩展的核心是 JAX 即时 (jax.jit) 编译器 jax.jit。在本指南中,您将使用 Flax 自己的 nnx.jit 转换,该转换包装了 jax.jit,并且更方便地与 Flax NNX Module 一起使用。

注意: 要了解有关 Flax 转换(如 nnx.jitnnx.vmap)的更多信息,请转到为什么选择 Flax NNX? - 转换转换Flax NNX 与 JAX 转换

JAX 编译遵循 单程序多数据 (SPMD) 范例。这意味着您编写 Python 代码时就好像它只在一个设备上运行,而 jax.jit自动编译在多个设备上运行

为了确保编译性能,您通常需要指示 JAX 如何在设备之间分片模型的变量。这就是 Flax NNX 的分片元数据 API - flax.nnx.spmd - 的用武之地。它可以帮助您使用此信息注释模型变量。

Flax Linen 用户注意flax.nnx.spmd API 类似于 (p)jit 指南中模型定义级别上描述的内容。但是,由于 Flax NNX 带来的好处,Flax NNX 中的顶层代码更简单,并且一些文本解释将更加更新和清晰。

如果您不熟悉 JAX 中的并行化,您可以在以下教程中了解有关其用于扩展的 API 的更多信息

设置#

导入一些必要的依赖项。

注意: 本指南使用 --xla_force_host_platform_device_count=8 标志在 Google Colab/Jupyter Notebook 的 CPU 环境中模拟多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import *

import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from flax import nnx

import optax # Optax for common losses and optimizers.
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

以下代码展示了如何导入和设置 JAX 级别的设备 API,遵循 JAX 的 分布式数组和自动并行化 指南

  1. 使用 JAX jax.sharding.Mesh 启动 2x4 设备 mesh(8 个设备)。此布局与 TPU v3-8(也是 8 个设备)上的布局相同。

  2. 使用 axis_names 参数使用名称注释每个轴。注释轴名称的典型方法是 axis_name=('data', 'model'),其中

  • 'data':用于输入和激活的批次维度的数据并行分片的网格维度。

  • 'model':用于在设备之间分片模型的参数的网格维度。

# Create a mesh of two dimensions and annotate each axis with a name.
mesh = Mesh(devices=np.array(jax.devices()).reshape(2, 4),
            axis_names=('data', 'model'))
print(mesh)
Mesh('data': 2, 'model': 4)

定义具有指定分片的模型#

接下来,创建一个名为 DotReluDot 的示例层,该层是 Flax nnx.Module 的子类。

  • 此层在输入 x 上执行两次点积乘法,并在其间使用 jax.nn.relu (ReLU) 激活函数。

  • 要使用其理想分片注释模型变量,可以使用 flax.nnx.with_partitioning 来包装其初始化函数。本质上,这会调用 flax.nnx.with_metadata,它会向相应的 nnx.Variable 添加一个 .sharding 属性字段。

注意: 此注解将在 Flax NNX 中提升的转换中保留并进行相应调整。这意味着,如果您将分片注解与任何修改轴的转换(如 nnx.vmapnnx.scan)一起使用,则需要通过 transform_metadata 参数提供该附加轴的分片。请查看Flax NNX 转换 (transforms) 指南以了解更多信息。

class DotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1` and annotate its kernel with.
    # `sharding (None, 'model')`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
      use_bias=False,  # or use `bias_init` to give it annotation too
      rngs=rngs)

    # Initialize a weight param `w2` and annotate with sharding ('model', None).
    # Note that this is simply adding `.sharding` to the variable as metadata!
    self.w2 = nnx.Param(
      init_fn(rngs.params(), (depth, depth)),  # RNG key and shape for W2 creation
      sharding=('model', None),
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # In data parallelism, input / intermediate value's first dimension (batch)
    # will be sharded on `data` axis
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
    z = jnp.dot(y, self.w2.value)
    return z

理解分片名称#

所谓的“分片注解”本质上是设备轴名称的元组,如 'data''model'None。这描述了此 JAX 数组的每个维度应如何分片 — 要么跨设备网格维度之一分片,要么根本不分片。

因此,当您定义形状为 (depth, depth)W1 并将其注解为 (None, 'model')

  • 第一个维度将在所有设备上复制。

  • 第二个维度将在设备网格的 'model' 轴上分片。这意味着 W1 将在此维度中在设备 (0, 4)(1, 5)(2, 6)(3, 7) 上进行 4 路分片。

JAX 的分布式数组和自动并行化指南提供了更多示例和解释。

初始化分片模型#

现在,您已经将注解附加到 Flax nnx.Variable 上,但实际的权重尚未分片。如果您直接创建此模型,则所有 jax.Arrays 仍会卡在设备 0 上。实际上,您应该避免这种情况,因为大型模型在这种情况下会“OOM”(导致设备内存不足),而所有其他设备都未被利用。

unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))

# You have annotations stuck there, yay!
print(unsharded_model.dot1.kernel.sharding)     # (None, 'model')
print(unsharded_model.w2.sharding)              # ('model', None)

# But the actual arrays are not sharded?
print(unsharded_model.dot1.kernel.value.sharding)  # SingleDeviceSharding
print(unsharded_model.w2.value.sharding)           # SingleDeviceSharding
(None, 'model')
('model', None)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

在这里,您应该通过 Flax 的 nnx.jit 来利用 JAX 的编译机制来创建分片模型。关键是在 jit 函数内初始化模型并在模型状态上分配分片。

  1. 使用 nnx.get_partition_spec 来剥离附加在模型变量上的 .sharding 注解。

  2. 调用 jax.lax.with_sharding_constraint 将模型状态与分片注解绑定。此 API 告诉顶层 jit 如何对变量进行分片!

  3. 丢弃未分片的状态,并返回基于分片状态的模型。

  4. 使用 nnx.jit 编译整个函数,这允许输出成为有状态的 Flax NNX Module

  5. 在设备网格上下文中运行它,以便 JAX 知道将其分片到哪些设备。

整个编译后的 create_sharded_model() 函数将直接生成带有分片 JAX 数组的模型,并且不会发生单设备“OOM”!

@nnx.jit
def create_sharded_model():
  model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.
  state = nnx.state(model)                   # The model's state, a pure pytree.
  pspecs = nnx.get_partition_spec(state)     # Strip out the annotations from state.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)           # The model is sharded now!
  return model

with mesh:
  sharded_model = create_sharded_model()

# They are some `GSPMDSharding` now - not a single device!
print(sharded_model.dot1.kernel.value.sharding)
print(sharded_model.w2.value.sharding)

# Check out their equivalency with some easier-to-read sharding descriptions
assert sharded_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=unpinned_host)

您可以使用 jax.debug.visualize_array_sharding 查看任何一维或二维数组的分片情况

print("sharded_model.dot1.kernel (None, 'model') :")
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)
print("sharded_model.w2 ('model', None) :")
jax.debug.visualize_array_sharding(sharded_model.w2.value)
sharded_model.dot1.kernel (None, 'model') :
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
sharded_model.w2 ('model', None) :
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

关于 jax.lax.with_sharding_constraint(半自动并行化)#

分片 JAX 数组的关键是在 jax.jit 函数内部调用 jax.lax.with_sharding_constraint。请注意,如果不在 JAX 设备网格上下文中,它将抛出错误。

注意:JAX 文档中的并行编程简介分布式数组和自动并行化都更详细地介绍了使用 jax.jit 进行自动并行化,以及使用 jax.jit`jax.lax.with_sharding_constraint 进行半自动并行化。

您可能已经注意到,您还在模型定义中使用了 jax.lax.with_sharding_constraint 一次,以约束中间值的分片。这只是为了表明,如果您想显式地对非模型变量的值进行分片,则始终可以与 Flax NNX API 正交地使用它。

这带来了一个问题:那么为什么要使用 Flax NNX 注解 API 呢?为什么不直接在模型定义中添加 JAX 分片约束呢?最重要的原因是,您仍然需要显式注解才能从磁盘上的检查点加载分片模型。这将在下一节中介绍。

从检查点加载分片模型#

现在,您已经学习了如何初始化分片模型而不会出现 OOM,但是如何从磁盘上的检查点加载它呢?JAX 检查点库(例如 Orbax)通常支持在提供分片 pytree 时加载分片模型。

您可以使用 Flax 的 nnx.get_named_sharding 来生成这样的分片 pytree。为避免任何实际的内存分配,请使用 nnx.eval_shape 转换来生成抽象 JAX 数组的模型,并仅使用其 .sharding 注解来获取分片树。

下面是一个演示使用 Orbax 的 StandardCheckpointer API 的示例。(转到Orbax 文档站点以了解其最新和最推荐的 API。)

import orbax.checkpoint as ocp

# Save the sharded state.
sharded_state = nnx.state(sharded_model)
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(path / 'checkpoint_name', sharded_state)

# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model)
# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`
# that contains both sharding and the shape/dtype of the arrays.
abs_state = jax.tree.map(
  lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
  abs_state, nnx.get_named_sharding(abs_state, mesh)
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
                                      target=abs_state)
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

编译训练循环#

现在,在初始化或加载检查点之后,您有一个分片模型。要执行编译后的扩展训练,您还需要对输入进行分片。

  • 在数据并行示例中,训练数据在其批处理维度上跨 data 设备轴进行分片,因此您应该将数据放在分片 ('data', None) 中。您可以使用 jax.device_put 来执行此操作。

  • 请注意,对于所有输入都使用正确的分片,即使没有 jit 编译,输出也将以最自然的方式进行分片。

  • 在下面的示例中,即使没有对输出 y 使用 jax.lax.with_sharding_constraint,它仍然被分片为 ('data', None)

如果您对原因感兴趣:DotReluDot.__call__ 的第二个 matmul 有两个输入,其分片分别为 ('data', 'model')('model', None),其中两个输入的收缩轴均为 model。因此,发生了 reduce-scatter matmul,并且自然会将输出分片为 ('data', None)。如果您想从数学上了解它是如何在底层发生的,请查看JAX 分片映射集体指南及其示例。

# In data parallelism, the first dimension (batch) will be sharded on the `data` axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)

with mesh:
  output = sharded_model(input)
print(output.shape)
jax.debug.visualize_array_sharding(output)  # Also sharded as `('data', None)`.
(8, 1024)
                                                                                
                                                                                
                                  CPU 0,1,2,3                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                  CPU 4,5,6,7                                   
                                                                                
                                                                                
                                                                                

现在,训练循环的其余部分非常传统 - 它几乎与Flax NNX Basics中的示例相同

  • 只是输入和标签也进行了显式分片。

  • nnx.jit 将根据其输入的已分片方式调整并自动选择最佳布局,因此请为自己的模型和输入尝试不同的分片。

optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))  # reference sharing

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model: DotReluDot):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)

  return loss

input = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)

with mesh:
  for i in range(5):
    loss = train_step(sharded_model, optimizer, input, label)
    print(loss)    # Model (over-)fitting to the labels quickly.
1.455235
0.7646729
0.50971293
0.378493
0.28089797

分析#

如果您使用的是 Google TPU pod 或 pod 切片,则可以创建一个自定义的 block_all() 实用程序函数(如下定义)来衡量性能

%%timeit

def block_all(xs):
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(sharded_model, optimizer, input, label))
57.6 ms ± 569 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

逻辑轴注解#

JAX 的自动 SPMD 鼓励用户探索不同的分片布局以找到最佳布局。为此,在 Flax 中,您可以选择使用更多描述性轴名称进行注解(不仅仅是像 'data''model' 这样的设备网格轴名称),只要您提供从别名到设备网格轴的映射即可。

您可以将映射以及注解作为相应 nnx.Variable 的另一个元数据提供,或者在顶层覆盖它。请查看下面的 LogicalDotReluDot() 示例。

# The mapping from alias annotation to the device mesh.
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))

class LogicalDotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_metadata(
        # Provide the sharding rules here.
        init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules),
      use_bias=False,
      rngs=rngs)

    # Initialize a weight param `w2`.
    self.w2 = nnx.Param(
      # Didn't provide the sharding rules here to show you how to overwrite it later.
      nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))(
        rngs.params(), (depth, depth))
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))
    z = jnp.dot(y, self.w2.value)
    return z

如果您没有在模型定义中提供所有 sharding_rule 注解,则可以编写几行代码将其添加到 Flax 的模型 nnx.State 中,然后在调用 nnx.get_partition_specnnx.get_named_sharding 之前。

def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState:
  vs.sharding_rules = sharding_rules
  return vs

@nnx.jit
def create_sharded_logical_model():
  model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))
  state = nnx.state(model)
  state = jax.tree.map(add_sharding_rule, state,
                       is_leaf=lambda x: isinstance(x, nnx.VariableState))
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
  return model

with mesh:
  sharded_logical_model = create_sharded_logical_model()

jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)
jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)

# Check out their equivalency with some easier-to-read sharding descriptions.
assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_logical_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)

with mesh:
  logical_output = sharded_logical_model(input)
  assert logical_output.sharding.is_equivalent_to(
    NamedSharding(mesh, PartitionSpec('data', None)), ndim=2
  )
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

何时使用设备轴/逻辑轴#

何时使用设备轴或逻辑轴取决于您想对模型的分区进行多少控制

  • 设备网格轴:

    • 对于更简单的模型,这可以为您节省一些将逻辑命名转换回设备命名的额外代码行。

    • 中间激活值的分片只能通过 jax.lax.with_sharding_constraint 和设备网格轴来实现。因此,如果您想对模型的分片进行超精细的控制,直接在各处使用设备网格轴名称可能会更清晰易懂。

  • 逻辑命名:如果您想进行实验并找到模型权重的最佳分区布局,这将很有帮助。