在多个设备上扩展 Flax 模块#

本指南展示了如何在多个设备和主机上使用 Flax 模块jax.jit(以前称为 experimental.pjit)和 flax.linen 扩展 Flax 模块

Flax 和 jax.jit 扩展#

jax.jit 遵循 单程序多数据 (SPMD) 范式,并自动编译您的代码以在多个设备上运行。您只需要指定您希望如何划分代码的输入和输出,编译器将找出如何:1) 划分内部的所有内容;以及 2) 编译设备间通信。

Flax 提供了一些功能,可以帮助您在 Flax 模块 上使用自动 SPMD,包括

  1. 在定义 flax.linen.Module 时,用于指定数据分区的接口。

  2. 用于生成 jax.jit 运行所需的碎片信息 的实用程序函数。

  3. 用于自定义轴名称的接口,称为“逻辑轴注释”,以将您的模块代码和分区计划解耦,以便更轻松地尝试不同的分区布局。

您可以在 JAX 的文档网站上的 多进程环境中的 JAX分布式数组和自动并行化 中了解有关 jax.jit 用于扩展的 API 的更多信息。

设置#

导入一些必要的依赖项。

注意:本指南使用 --xla_force_host_platform_device_count=8 标志在 Google Colab/Jupyter Notebook 中的 CPU 环境中模拟多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。

# Once Flax v0.6.10 is released, there is no need to do this.
# ! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import functools
from typing import Optional, Callable

import numpy as np
import jax
from jax import lax, random, numpy as jnp

import flax
from flax import struct, traverse_util, linen as nn
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints

import optax # Optax for common losses and optimizers.
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
print(f'We have 8 fake JAX devices now: {jax.devices()}')
We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

以下代码展示了如何导入和设置 JAX 级别设备 API,遵循 JAX 的 分布式数组和自动并行化 指南

  1. 使用 JAX 的 mesh_utils.create_device_mesh 启动一个 2x4 设备 mesh(8 个设备)。此布局与 TPU v3-8 的布局相同。

  2. 使用 jax.sharding.Mesh 中的 axis_names 参数为每个轴添加一个名称。为轴名称添加注释的典型方法是 axis_name=('data', 'model'),其中

  • 'data':用于输入和激活的批量维度的 数据并行分片的网格维度。

  • 'model':用于在设备之间对模型参数进行分片的 网格维度。

  1. 使用一个简单的实用程序函数 mesh_sharding 从网格和任何布局生成一个分片对象。

from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
# Create a mesh and annotate each axis with a name.
device_mesh = mesh_utils.create_device_mesh((2, 4))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
  return NamedSharding(mesh, pspec)
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
 [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh('data': 2, 'model': 4)

定义一个层#

在定义一个简单的模型之前,创建一个名为 DotReluDot 的示例层(通过子类化 flax.linen.Module)。该层为点积乘法创建了两个参数 W1W2,并在两者之间使用了 jax.nn.relu(ReLU)激活函数。

要有效地对参数进行分片,请应用以下 API 来注释参数和中间变量

  1. 使用 flax.linen.with_partitioning 装饰创建子层或原始参数时的初始化函数。

  2. 应用 jax.lax.with_sharding_constraint(以前为 pjit.with_sharding_constraint)注释中间变量,如 yz,以在已知理想约束时强制使用特定分片模式。

  • 此步骤是可选的,但有时可以帮助自动 SPMD 有效地进行分区。在下面的示例中,不需要调用,因为 XLA 将为 yz 找出相同的分片布局,无论如何。

class DotReluDot(nn.Module):
  depth: int
  dense_init: Callable = nn.initializers.xavier_normal()
  @nn.compact
  def __call__(self, x):

    y = nn.Dense(self.depth,
                 kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
                 use_bias=False,  # or overwrite with `bias_init`
                 )(x)

    y = jax.nn.relu(y)
    # Force a local sharding annotation.
    y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))

    W2 = self.param(
        'W2',
        nn.with_partitioning(self.dense_init, ('model', None)),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))

    # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
    return z, None

请注意,设备轴名称,如 'data''model'None,将传递给 flax.linen.with_partitioningjax.lax.with_sharding_constraint API 调用。这指的是如何对数据的每个维度进行分片——是在设备网格维度之一中分片,还是根本不分片。

例如

  • 当您定义形状为 (x.shape[-1], self.depth)W1 并将其注释为 (None, 'model')

    • 第一个维度(长度为 x.shape[-1])将在所有设备上复制。

    • 第二个维度(长度为 self.depth)将在设备网格的 'model' 轴上进行分片。这意味着 W1 将在设备 (0, 4)(1, 5)(2, 6)(3, 7) 上进行 4 路分片,在这个维度上。

  • 当您将输出 z 注释为 ('data', None)

    • 第一个维度——批量维度——将在 'data' 轴上进行分片。这意味着一半的批量将在设备 0-3(前四个设备)上处理,另一半将在设备 4-7(剩余的四个设备)上处理。

    • 第二个维度——数据深度维度——将在所有设备上复制。

使用 flax.linen.scan 提升转换定义一个模型#

创建了 DotReluDot 之后,您现在可以定义 MLP 模型(通过子类化 flax.linen.Module)作为多个 DotReluDot 层。

要复制相同的层,您可以使用 flax.linen.scan 或 for 循环。

  • flax.linen.scan 可以提供更快的编译时间。

  • for 循环在运行时可能更快。

以下代码展示了如何应用两种方法,并默认使用 for 循环,以便所有参数都是二维的,您可以可视化它们的分片。

代码 flax.linen.scan 只是为了展示这个 API 如何与 Flax 提升的变换 一起使用。

class MLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(DotReluDot, length=self.num_layers,
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARTITION_NAME: None}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = DotReluDot(self.depth)(x)
    return x

现在,创建一个 model 实例,以及一个样本输入 x

# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = random.key(0)

# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)

指定分片#

接下来,你需要告诉 jax.jit 如何将我们的数据跨设备分片。

输入的分片#

对于数据并行,你可以通过将批次轴表示为 'data',将批处理的输入 xdata 轴分片。然后,使用 jax.device_put 将其放置在正确的 device 上。

x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length)
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)
┌──────────────────────────────────────────────────────────────────────────────┐
│                                                                              │
│                                 CPU 0,1,2,3                                  │
│                                                                              │
│                                                                              │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│                                 CPU 4,5,6,7                                  │
│                                                                              │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘

输出的分片#

你需要编译 model.init()(也就是 flax.linen.Module.init()),以及它的输出作为参数的 pytree。此外,你可能有时需要用 flax.training.train_state 来跟踪其他变量,例如优化器状态,这将使输出成为一个更复杂的 pytree。

幸运的是,要实现这一点,你无需手动硬编码输出的分片。相反,你可以

  1. 使用 jax.eval_shape 抽象地评估 model.init(在本例中,是一个包装器)。

  2. 使用 flax.linen.get_sharding 自动生成 jax.sharding.NamedSharding

def init_fn(k, x, model, optimizer):
  variables = model.init(k, x) # Initialize the model.
  state = train_state.TrainState.create( # Create a `TrainState`.
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer)
  return state
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer), k, x)

# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
state_sharding = nn.get_sharding(abstract_variables, mesh)
state_sharding
TrainState(step=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), apply_fn=<bound method Module.apply of MLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = False
)>, params={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x33e134280>, update=<function chain.<locals>.update_fn at 0x33e134430>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), mu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}, nu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}), EmptyState()))

编译代码#

现在你可以将 jax.jit 应用于你的 init_fn,但有两个额外的参数:in_shardingsout_shardings

运行它以获取 initialized_state,其中参数按照指示进行分片。

jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                      in_shardings=(mesh_sharding(()), x_sharding),  # PRNG key and x
                      out_shardings=state_sharding)

initialized_state = jit_init_fn(k, x, model, optimizer)

# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
#     print(f'Sharding of {weight}: {partitioned.names}')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
┌───────────────────────┐
│        CPU 0,4        │
├───────────────────────┤
│        CPU 1,5        │
├───────────────────────┤
│        CPU 2,6        │
├───────────────────────┤
│        CPU 3,7        │
└───────────────────────┘

检查模块输出#

请注意,在 initialized_state 的输出中,params W1W2 的类型为 flax.linen.Partitioned。这是一个围绕实际 jax.Array 的包装器,它允许 Flax 记录与它关联的轴名称。

你可以通过对字典调用 flax.linen.meta.unbox(),或者对单个变量调用 .value 来访问原始的 jax.Array。你也可以使用 flax.linen.meta.replace_boxed() 来更改底层的 jax.Array,而不会修改分片注解。

print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)
<class 'flax.core.meta.Partitioned'>
<class 'jaxlib.xla_extension.ArrayImpl'>
(None, 'model')
(1024, 1024)
# Say for some unknown reason you want to make the whole param tree all-zero
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0

你还可以检查每个参数的底层 jax.sharding,它现在比 NamedSharding 更内部。请注意,像 initialized_state.step 这样的数字在所有设备上都已复制。

initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))
print(initialized_state.step)
initialized_state.step.sharding
0
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec())

你可以使用 jax.tree_util.tree_map 对盒装参数字典进行批量计算,就像在 JAX 数组字典上一样。

diff = jax.tree_util.tree_map(
    lambda a, b: a - b,
    initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0'])
print(jax.tree_util.tree_map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)
{'Dense_0': {'kernel': Partitioned(value=(1024, 1024), names=(None, 'model'), mesh=None)}, 'W2': Partitioned(value=(1024, 1024), names=('model', None), mesh=None)}
<class 'jaxlib.xla_extension.ArrayImpl'>
(1024, 1024)

编译训练步骤和推理#

创建一个 jit 的训练步骤,如下所示

@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
                   out_shardings=state_sharding)
def train_step(state, x):
  # A fake loss function.
  def loss_unrolled(params):
    y = model.apply({'params': params}, x)
    return y.sum()
  grad_fn = jax.grad(loss_unrolled)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

with mesh:
  new_state = train_step(initialized_state, x)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
Sharding of Weight 1:
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
Sharding of Weight 2:
┌───────────────────────┐
│        CPU 0,4        │
├───────────────────────┤
│        CPU 1,5        │
├───────────────────────┤
│        CPU 2,6        │
├───────────────────────┤
│        CPU 3,7        │
└───────────────────────┘

然后,创建一个编译的推理步骤。请注意,输出也沿着 (data, None) 分片。

@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
                   out_shardings=x_sharding)
def apply_fn(state, x):
  return state.apply_fn({'params': state.params}, x)

with mesh:
  y = apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)
<class 'jaxlib.xla_extension.ArrayImpl'>
float32
(8, 1024)
┌──────────────────────────────────────────────────────────────────────────────┐
│                                                                              │
│                                 CPU 0,1,2,3                                  │
│                                                                              │
│                                                                              │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│                                 CPU 4,5,6,7                                  │
│                                                                              │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘

性能分析#

如果你在 TPU 集群或集群切片上运行,你可以使用一个自定义的 block_all 实用函数(如下定义)来衡量性能

%%timeit

def block_all(xs):
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(initialized_state, x))
20.9 ms ± 319 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

逻辑轴注解#

JAX 的自动 SPMD 鼓励用户探索不同的分片布局,以找到最佳的布局。为此,在 Flax 中,你实际上可以使用更具描述性的轴名称(不仅仅是像 'data''model' 这样的设备网格轴名称)来注解任何数据的维度。

下面的 LogicalDotReluDotLogicalMLP 模块定义与你之前创建的模块类似,除了以下内容

  1. 所有轴都用更具体、更有意义的名称进行注解,例如 'embed''hidden''batch''layer'。在 Flax 中,这些名称被称为逻辑轴名称。它们使模型定义中的维度变化更易读。

  2. flax.linen.with_logical_partitioning 替换了 flax.linen.with_partitioning;而 flax.linen.with_logical_constraint 替换了 jax.lax.with_sharding_constraint,以识别逻辑轴名称。

class LogicalDotReluDot(nn.Module):
  depth: int
  dense_init: Callable = nn.initializers.xavier_normal()
  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.depth,
                 kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
                 use_bias=False,  # or overwrite with `bias_init`
                 )(x)

    y = jax.nn.relu(y)
    # Force a local sharding annotation.
    y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))

    W2 = self.param(
        'W2',
        nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = nn.with_logical_constraint(z, ('batch', 'embed'))
    return z, None

class LogicalMLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
                    variable_axes={"params": 0},
                    split_rngs={"params": True},
                    metadata_params={nn.PARTITION_NAME: 'layer'}
                    )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = LogicalDotReluDot(self.depth)(x)
    return x

现在,初始化一个模型,并尝试找出其 state 应该具有哪些分片。

为了让设备网格正确地使用你的模型,你需要确定这些逻辑轴名称中的哪些映射到设备轴 'data''model'。此规则是一个 (logical_axis_namedevice_axis_name) 元组列表,而 flax.linen.logical_to_mesh_sharding 将它们转换为设备网格可以理解的分片类型。

这允许你更改规则,尝试新的分区布局,而无需修改模型定义。

# Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.
rules = (('batch', 'data'),
         ('hidden', 'model'))

logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)

logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x)
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print('annotations are logical, not mesh-specific: ',
      logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'])

logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)
print('sharding annotations are mesh-specific: ',
      logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)
annotations are logical, not mesh-specific:  PartitionSpec('embed', 'hidden')
sharding annotations are mesh-specific:  PartitionSpec(None, 'model')

你可以验证这里的 logical_state_spec 与之前(“非逻辑”)示例中的 state_spec 具有相同的内容。这允许你以相同的方式对你的模块的 flax.linen.Module.initflax.linen.Module.apply 使用 jax.jit

state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0']
True
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                      in_shardings=(mesh_sharding(()), x_sharding),  # PRNG key and x
                      out_shardings=logical_state_sharding)

logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value)
Sharding of Weight 1:
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
Sharding of Weight 2:
┌───────────────────────┐
│        CPU 0,4        │
├───────────────────────┤
│        CPU 1,5        │
├───────────────────────┤
│        CPU 2,6        │
├───────────────────────┤
│        CPU 3,7        │
└───────────────────────┘

何时使用设备轴/逻辑轴#

选择何时使用设备轴或逻辑轴取决于你想要控制模型分区的程度

  • 设备网格轴:如果你想要一个非常简单的模型,或者你对你的分区方式非常有信心,使用设备网格轴定义它可能会为你节省几行将逻辑命名转换回设备命名的代码。

  • 逻辑命名:另一方面,逻辑命名助手对于探索不同的分片布局非常有用。如果你想要进行实验并为你的模型找到最优的分区布局,请使用此方法。

  • 设备轴名称:在非常高级的用例中,你可能拥有更复杂的分片模式,这些模式需要用与参数维度名称不同的方式来注解激活维度名称。如果你希望对手动网格分配有更细粒度的控制,直接使用设备轴名称可能更有帮助。

保存数据#

要保存跨设备数组,你可以使用 flax.training.checkpoints,如 保存和加载检查点指南 - 多主机/多进程检查点 中所示。这在你在多主机环境(例如 TPU 集群)中运行时尤其需要。

在实践中,你可能希望将原始的 jax.Array pytree 作为检查点保存,而不是包装的 Partitioned 值,以减少复杂性。你可以按原样恢复它,并使用 flax.linen.meta.replace_boxed() 将其放回带有注解的 pytree 中。

请记住,要将数组恢复到所需的分区,你需要提供一个样本 target pytree,该 pytree 具有相同的结构,并且每个 JAX 数组都具有所需的 jax.sharding.Sharding。你用来恢复数组的分片不必与你用来存储数组的分片相同。