在多个设备上扩展 Flax 模块#
本指南展示了如何在多个设备和主机上使用 Flax 模块 和 jax.jit
(以前称为 experimental.pjit
)和 flax.linen
扩展 Flax 模块。
Flax 和 jax.jit
扩展#
jax.jit
遵循 单程序多数据 (SPMD) 范式,并自动编译您的代码以在多个设备上运行。您只需要指定您希望如何划分代码的输入和输出,编译器将找出如何:1) 划分内部的所有内容;以及 2) 编译设备间通信。
Flax 提供了一些功能,可以帮助您在 Flax 模块 上使用自动 SPMD,包括
在定义
flax.linen.Module
时,用于指定数据分区的接口。用于生成
jax.jit
运行所需的碎片信息 的实用程序函数。用于自定义轴名称的接口,称为“逻辑轴注释”,以将您的模块代码和分区计划解耦,以便更轻松地尝试不同的分区布局。
您可以在 JAX 的文档网站上的 多进程环境中的 JAX 和 分布式数组和自动并行化 中了解有关 jax.jit
用于扩展的 API 的更多信息。
设置#
导入一些必要的依赖项。
注意:本指南使用 --xla_force_host_platform_device_count=8
标志在 Google Colab/Jupyter Notebook 中的 CPU 环境中模拟多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。
# Once Flax v0.6.10 is released, there is no need to do this.
# ! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import functools
from typing import Optional, Callable
import numpy as np
import jax
from jax import lax, random, numpy as jnp
import flax
from flax import struct, traverse_util, linen as nn
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints
import optax # Optax for common losses and optimizers.
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
print(f'We have 8 fake JAX devices now: {jax.devices()}')
We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
以下代码展示了如何导入和设置 JAX 级别设备 API,遵循 JAX 的 分布式数组和自动并行化 指南
使用 JAX 的
mesh_utils.create_device_mesh
启动一个 2x4 设备mesh
(8 个设备)。此布局与 TPU v3-8 的布局相同。使用
jax.sharding.Mesh
中的axis_names
参数为每个轴添加一个名称。为轴名称添加注释的典型方法是axis_name=('data', 'model')
,其中
'data'
:用于输入和激活的批量维度的 数据并行分片的网格维度。'model'
:用于在设备之间对模型参数进行分片的 网格维度。
使用一个简单的实用程序函数
mesh_sharding
从网格和任何布局生成一个分片对象。
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
# Create a mesh and annotate each axis with a name.
device_mesh = mesh_utils.create_device_mesh((2, 4))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec)
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
[CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh('data': 2, 'model': 4)
定义一个层#
在定义一个简单的模型之前,创建一个名为 DotReluDot
的示例层(通过子类化 flax.linen.Module
)。该层为点积乘法创建了两个参数 W1
和 W2
,并在两者之间使用了 jax.nn.relu
(ReLU)激活函数。
要有效地对参数进行分片,请应用以下 API 来注释参数和中间变量
使用
flax.linen.with_partitioning
装饰创建子层或原始参数时的初始化函数。应用
jax.lax.with_sharding_constraint
(以前为pjit.with_sharding_constraint
)注释中间变量,如y
和z
,以在已知理想约束时强制使用特定分片模式。
此步骤是可选的,但有时可以帮助自动 SPMD 有效地进行分区。在下面的示例中,不需要调用,因为 XLA 将为
y
和z
找出相同的分片布局,无论如何。
class DotReluDot(nn.Module):
depth: int
dense_init: Callable = nn.initializers.xavier_normal()
@nn.compact
def __call__(self, x):
y = nn.Dense(self.depth,
kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
use_bias=False, # or overwrite with `bias_init`
)(x)
y = jax.nn.relu(y)
# Force a local sharding annotation.
y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))
W2 = self.param(
'W2',
nn.with_partitioning(self.dense_init, ('model', None)),
(self.depth, x.shape[-1]))
z = jnp.dot(y, W2)
# Force a local sharding annotation.
z = with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))
# Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
return z, None
请注意,设备轴名称,如 'data'
、'model'
或 None
,将传递给 flax.linen.with_partitioning
和 jax.lax.with_sharding_constraint
API 调用。这指的是如何对数据的每个维度进行分片——是在设备网格维度之一中分片,还是根本不分片。
例如
当您定义形状为
(x.shape[-1], self.depth)
的W1
并将其注释为(None, 'model')
时第一个维度(长度为
x.shape[-1]
)将在所有设备上复制。第二个维度(长度为
self.depth
)将在设备网格的'model'
轴上进行分片。这意味着W1
将在设备(0, 4)
、(1, 5)
、(2, 6)
和(3, 7)
上进行 4 路分片,在这个维度上。
当您将输出
z
注释为('data', None)
时第一个维度——批量维度——将在
'data'
轴上进行分片。这意味着一半的批量将在设备0-3
(前四个设备)上处理,另一半将在设备4-7
(剩余的四个设备)上处理。第二个维度——数据深度维度——将在所有设备上复制。
使用 flax.linen.scan
提升转换定义一个模型#
创建了 DotReluDot
之后,您现在可以定义 MLP
模型(通过子类化 flax.linen.Module
)作为多个 DotReluDot
层。
要复制相同的层,您可以使用 flax.linen.scan
或 for 循环。
flax.linen.scan
可以提供更快的编译时间。for 循环在运行时可能更快。
以下代码展示了如何应用两种方法,并默认使用 for 循环,以便所有参数都是二维的,您可以可视化它们的分片。
代码 flax.linen.scan
只是为了展示这个 API 如何与 Flax 提升的变换 一起使用。
class MLP(nn.Module):
num_layers: int
depth: int
use_scan: bool
@nn.compact
def __call__(self, x):
if self.use_scan:
x, _ = nn.scan(DotReluDot, length=self.num_layers,
variable_axes={"params": 0},
split_rngs={"params": True},
metadata_params={nn.PARTITION_NAME: None}
)(self.depth)(x)
else:
for i in range(self.num_layers):
x, _ = DotReluDot(self.depth)(x)
return x
现在,创建一个 model
实例,以及一个样本输入 x
。
# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = random.key(0)
# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)
指定分片#
接下来,你需要告诉 jax.jit
如何将我们的数据跨设备分片。
输入的分片#
对于数据并行,你可以通过将批次轴表示为 'data'
,将批处理的输入 x
跨 data
轴分片。然后,使用 jax.device_put
将其放置在正确的 device
上。
x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length)
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)
┌──────────────────────────────────────────────────────────────────────────────┐ │ │ │ CPU 0,1,2,3 │ │ │ │ │ ├──────────────────────────────────────────────────────────────────────────────┤ │ │ │ CPU 4,5,6,7 │ │ │ │ │ └──────────────────────────────────────────────────────────────────────────────┘
输出的分片#
你需要编译 model.init()
(也就是 flax.linen.Module.init()
),以及它的输出作为参数的 pytree。此外,你可能有时需要用 flax.training.train_state
来跟踪其他变量,例如优化器状态,这将使输出成为一个更复杂的 pytree。
幸运的是,要实现这一点,你无需手动硬编码输出的分片。相反,你可以
使用
jax.eval_shape
抽象地评估model.init
(在本例中,是一个包装器)。使用
flax.linen.get_sharding
自动生成jax.sharding.NamedSharding
。此步骤利用了之前定义中的
flax.linen.with_partitioning
注解,为参数生成正确的分片。
def init_fn(k, x, model, optimizer):
variables = model.init(k, x) # Initialize the model.
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=variables['params'],
tx=optimizer)
return state
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=optimizer), k, x)
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
state_sharding = nn.get_sharding(abstract_variables, mesh)
state_sharding
TrainState(step=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), apply_fn=<bound method Module.apply of MLP(
# attributes
num_layers = 4
depth = 1024
use_scan = False
)>, params={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x33e134280>, update=<function chain.<locals>.update_fn at 0x33e134430>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), mu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}, nu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}), EmptyState()))
编译代码#
现在你可以将 jax.jit
应用于你的 init_fn
,但有两个额外的参数:in_shardings
和 out_shardings
。
运行它以获取 initialized_state
,其中参数按照指示进行分片。
jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
in_shardings=(mesh_sharding(()), x_sharding), # PRNG key and x
out_shardings=state_sharding)
initialized_state = jit_init_fn(k, x, model, optimizer)
# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
# print(f'Sharding of {weight}: {partitioned.names}')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
┌───────────────────────┐ │ CPU 0,4 │ ├───────────────────────┤ │ CPU 1,5 │ ├───────────────────────┤ │ CPU 2,6 │ ├───────────────────────┤ │ CPU 3,7 │ └───────────────────────┘
检查模块输出#
请注意,在 initialized_state
的输出中,params
W1
和 W2
的类型为 flax.linen.Partitioned
。这是一个围绕实际 jax.Array
的包装器,它允许 Flax 记录与它关联的轴名称。
你可以通过对字典调用 flax.linen.meta.unbox()
,或者对单个变量调用 .value
来访问原始的 jax.Array
。你也可以使用 flax.linen.meta.replace_boxed()
来更改底层的 jax.Array
,而不会修改分片注解。
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)
<class 'flax.core.meta.Partitioned'>
<class 'jaxlib.xla_extension.ArrayImpl'>
(None, 'model')
(1024, 1024)
# Say for some unknown reason you want to make the whole param tree all-zero
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0
你还可以检查每个参数的底层 jax.sharding
,它现在比 NamedSharding
更内部。请注意,像 initialized_state.step
这样的数字在所有设备上都已复制。
initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))
print(initialized_state.step)
initialized_state.step.sharding
0
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec())
你可以使用 jax.tree_util.tree_map
对盒装参数字典进行批量计算,就像在 JAX 数组字典上一样。
diff = jax.tree_util.tree_map(
lambda a, b: a - b,
initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0'])
print(jax.tree_util.tree_map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)
{'Dense_0': {'kernel': Partitioned(value=(1024, 1024), names=(None, 'model'), mesh=None)}, 'W2': Partitioned(value=(1024, 1024), names=('model', None), mesh=None)}
<class 'jaxlib.xla_extension.ArrayImpl'>
(1024, 1024)
编译训练步骤和推理#
创建一个 jit
的训练步骤,如下所示
@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
out_shardings=state_sharding)
def train_step(state, x):
# A fake loss function.
def loss_unrolled(params):
y = model.apply({'params': params}, x)
return y.sum()
grad_fn = jax.grad(loss_unrolled)
grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
with mesh:
new_state = train_step(initialized_state, x)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
Sharding of Weight 1:
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
Sharding of Weight 2:
┌───────────────────────┐ │ CPU 0,4 │ ├───────────────────────┤ │ CPU 1,5 │ ├───────────────────────┤ │ CPU 2,6 │ ├───────────────────────┤ │ CPU 3,7 │ └───────────────────────┘
然后,创建一个编译的推理步骤。请注意,输出也沿着 (data, None)
分片。
@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
out_shardings=x_sharding)
def apply_fn(state, x):
return state.apply_fn({'params': state.params}, x)
with mesh:
y = apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)
<class 'jaxlib.xla_extension.ArrayImpl'>
float32
(8, 1024)
┌──────────────────────────────────────────────────────────────────────────────┐ │ │ │ CPU 0,1,2,3 │ │ │ │ │ ├──────────────────────────────────────────────────────────────────────────────┤ │ │ │ CPU 4,5,6,7 │ │ │ │ │ └──────────────────────────────────────────────────────────────────────────────┘
性能分析#
如果你在 TPU 集群或集群切片上运行,你可以使用一个自定义的 block_all
实用函数(如下定义)来衡量性能
%%timeit
def block_all(xs):
jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
return xs
with mesh:
new_state = block_all(train_step(initialized_state, x))
20.9 ms ± 319 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
逻辑轴注解#
JAX 的自动 SPMD 鼓励用户探索不同的分片布局,以找到最佳的布局。为此,在 Flax 中,你实际上可以使用更具描述性的轴名称(不仅仅是像 'data'
和 'model'
这样的设备网格轴名称)来注解任何数据的维度。
下面的 LogicalDotReluDot
和 LogicalMLP
模块定义与你之前创建的模块类似,除了以下内容
所有轴都用更具体、更有意义的名称进行注解,例如
'embed'
、'hidden'
、'batch'
和'layer'
。在 Flax 中,这些名称被称为逻辑轴名称。它们使模型定义中的维度变化更易读。flax.linen.with_logical_partitioning
替换了flax.linen.with_partitioning
;而flax.linen.with_logical_constraint
替换了jax.lax.with_sharding_constraint
,以识别逻辑轴名称。
class LogicalDotReluDot(nn.Module):
depth: int
dense_init: Callable = nn.initializers.xavier_normal()
@nn.compact
def __call__(self, x):
y = nn.Dense(self.depth,
kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
use_bias=False, # or overwrite with `bias_init`
)(x)
y = jax.nn.relu(y)
# Force a local sharding annotation.
y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))
W2 = self.param(
'W2',
nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
(self.depth, x.shape[-1]))
z = jnp.dot(y, W2)
# Force a local sharding annotation.
z = nn.with_logical_constraint(z, ('batch', 'embed'))
return z, None
class LogicalMLP(nn.Module):
num_layers: int
depth: int
use_scan: bool
@nn.compact
def __call__(self, x):
if self.use_scan:
x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
variable_axes={"params": 0},
split_rngs={"params": True},
metadata_params={nn.PARTITION_NAME: 'layer'}
)(self.depth)(x)
else:
for i in range(self.num_layers):
x, _ = LogicalDotReluDot(self.depth)(x)
return x
现在,初始化一个模型,并尝试找出其 state
应该具有哪些分片。
为了让设备网格正确地使用你的模型,你需要确定这些逻辑轴名称中的哪些映射到设备轴 'data'
或 'model'
。此规则是一个 (logical_axis_name
、device_axis_name
) 元组列表,而 flax.linen.logical_to_mesh_sharding
将它们转换为设备网格可以理解的分片类型。
这允许你更改规则,尝试新的分区布局,而无需修改模型定义。
# Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.
rules = (('batch', 'data'),
('hidden', 'model'))
logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)
logical_abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x)
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print('annotations are logical, not mesh-specific: ',
logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'])
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)
print('sharding annotations are mesh-specific: ',
logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)
annotations are logical, not mesh-specific: PartitionSpec('embed', 'hidden')
sharding annotations are mesh-specific: PartitionSpec(None, 'model')
你可以验证这里的 logical_state_spec
与之前(“非逻辑”)示例中的 state_spec
具有相同的内容。这允许你以相同的方式对你的模块的 flax.linen.Module.init
和 flax.linen.Module.apply
使用 jax.jit
。
state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0']
True
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
in_shardings=(mesh_sharding(()), x_sharding), # PRNG key and x
out_shardings=logical_state_sharding)
logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value)
Sharding of Weight 1:
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
Sharding of Weight 2:
┌───────────────────────┐ │ CPU 0,4 │ ├───────────────────────┤ │ CPU 1,5 │ ├───────────────────────┤ │ CPU 2,6 │ ├───────────────────────┤ │ CPU 3,7 │ └───────────────────────┘
何时使用设备轴/逻辑轴#
选择何时使用设备轴或逻辑轴取决于你想要控制模型分区的程度
设备网格轴:如果你想要一个非常简单的模型,或者你对你的分区方式非常有信心,使用设备网格轴定义它可能会为你节省几行将逻辑命名转换回设备命名的代码。
逻辑命名:另一方面,逻辑命名助手对于探索不同的分片布局非常有用。如果你想要进行实验并为你的模型找到最优的分区布局,请使用此方法。
设备轴名称:在非常高级的用例中,你可能拥有更复杂的分片模式,这些模式需要用与参数维度名称不同的方式来注解激活维度名称。如果你希望对手动网格分配有更细粒度的控制,直接使用设备轴名称可能更有帮助。
保存数据#
要保存跨设备数组,你可以使用 flax.training.checkpoints
,如 保存和加载检查点指南 - 多主机/多进程检查点 中所示。这在你在多主机环境(例如 TPU 集群)中运行时尤其需要。
在实践中,你可能希望将原始的 jax.Array
pytree 作为检查点保存,而不是包装的 Partitioned
值,以减少复杂性。你可以按原样恢复它,并使用 flax.linen.meta.replace_boxed()
将其放回带有注解的 pytree 中。
请记住,要将数组恢复到所需的分区,你需要提供一个样本 target
pytree,该 pytree 具有相同的结构,并且每个 JAX 数组都具有所需的 jax.sharding.Sharding
。你用来恢复数组的分片不必与你用来存储数组的分片相同。