保存和加载检查点#

本指南演示了如何使用 Orbax 保存和加载 Flax 检查点。

Orbax 提供了多种用于保存和加载模型数据的功能,您将在本文档中了解到这些功能。

  • 支持各种数组类型和存储格式

  • 异步保存以减少训练等待时间

  • 过去检查点的版本控制和自动簿记

  • 灵活的 transformations 用于调整和加载旧检查点

  • jax.sharding 基于 API 在多主机场景中保存和加载


正在进行的迁移到 Orbax

2023 年 7 月 30 日之后,Flax 的旧版 flax.training.checkpoints API 将被弃用,转而使用 Orbax

  • 如果您是新的 Flax 用户:请使用新的 orbax.checkpoint API,如本指南中所示。

  • 如果您在项目中使用了旧版 flax.training.checkpoints 代码:请考虑以下选项。

    • 将代码迁移到 Orbax(推荐):按照此 迁移指南 将 API 调用迁移到 orbax.checkpoint API。

    • 自动使用 Orbax 后端:在项目中添加 flax.config.update('flax_use_orbax_checkpointing', True),这将使您的 flax.training.checkpoints 调用自动使用 Orbax 后端来保存您的检查点。

      • 计划的翻转:这将成为2023 年 5 月(暂定日期)之后默认模式。

      • 如果您在自动迁移过程中遇到任何问题,请访问 Orbax 作为后端故障排除部分


为了向后兼容,本指南展示了 Flax 旧版 flax.training.checkpoints API 中的 Orbax 等效调用。

如果您需要了解有关 orbax.checkpoint 的更多信息,请参阅 Orbax 文档

设置#

安装/升级 Flax 和 Orbax。有关支持 GPU/TPU 的 JAX 安装,请访问 GitHub 上的此部分

注意:在运行 import jax 之前,创建八个假设备以模拟此笔记本中的 多主机环境。请注意,此处的导入顺序很重要。命令 os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' 仅适用于 CPU 后端,这意味着它不适用于 GPU/TPU 加速,如果您在 Google Colab 中运行此笔记本,则也不适用。如果您已经在多个设备上运行代码(例如,在 4x2 TPU 环境中),则可以跳过运行下一个单元格。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
ckpt_dir = '/tmp/flax_ckpt'

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

保存检查点#

在 Orbax 和 Flax 中,您可以保存和加载任何给定的 JAX pytree。这不仅包括典型的 Python 和 NumPy 容器,还包括从 flax.struct.dataclass 扩展的自定义类。这意味着您可以存储几乎所有生成的数据——不仅是模型参数,还有任何数组/字典、元数据/配置等等。

首先,创建一个包含许多数据结构和容器的 pytree,并使用它。

# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
ckpt
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695322343.254588       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

使用 Orbax#

使用 orbax.checkpoint.PyTreeCheckpointer 将检查点直接保存到 tmp/orbax/single_save 目录。

注意:提供了一个可选的 save_args。建议使用此参数以提高性能速度,因为它将 pytree 中的较小数组捆绑到单个大文件,而不是多个较小文件。

from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)

接下来,要使用版本控制和自动簿记功能,您需要将 orbax.checkpoint.CheckpointManager 包装在 orbax.checkpoint.PyTreeCheckpointer 之上。

此外,请提供 orbax.checkpoint.CheckpointManagerOptions,它可以自定义您的需求,例如您希望以什么频率以及根据什么标准删除旧检查点。有关提供的选项的完整列表,请参阅 文档

orbax.checkpoint.CheckpointManager 应该放置在训练步骤之外的顶层以管理您的保存操作。

options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

# Inside a training loop
for step in range(5):
    # ... do your training
    checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})

os.listdir('/tmp/flax_ckpt/orbax/managed')  # Because max_to_keep=2, only step 3 and 4 are retained
['4', '3']

使用旧 API#

以下是使用旧版 Flax 检查点实用程序保存的方法(请注意,与 orbax.checkpoint.CheckpointManagerOptions 相比,此方法提供的管理功能更少)。

# Import Flax Checkpoints.
from flax.training import checkpoints

checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=0,
                            overwrite=True,
                            keep=2)
'/tmp/flax_ckpt/flax-checkpointing/checkpoint_0'

恢复检查点#

使用 Orbax#

在 Orbax 中,调用 .restore() 以获取 orbax.checkpoint.PyTreeCheckpointerorbax.checkpoint.CheckpointManager,以原始 pytree 格式恢复您的检查点。

raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save')
raw_restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

请注意,CheckpointManger 需要 step 编号。您还可以使用 .latest_step() 查找可用的最新步骤。

step = checkpoint_manager.latest_step()  # step = 4
checkpoint_manager.restore(step)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

使用旧 API#

请注意,随着迁移到 Orbax 的进行,flax.training.checkpointing.restore_checkpoint 可以自动识别检查点是使用旧版 Flax 格式保存还是使用 Orbax 后端保存,并正确恢复 pytree。因此,添加 flax.config.update('flax_use_orbax_checkpointing', True) 不会影响您恢复旧检查点的能力。

以下是使用旧版 API 恢复检查点的方法。

raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None)
raw_restored
{'config': {'dimensions': array([5, 3])},
 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)},
 'model': {'opt_state': {'0': None, '1': None},
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

使用自定义数据类恢复#

使用 Orbax#

  • 在上一个示例中恢复的 pytree 采用原始字典的形式。原始 pytree 包含自定义数据类,例如 TrainStateoptax 状态。

  • 这是因为在恢复 pytree 时,程序还不知道它曾经属于哪个结构。

  • 要解决此问题,您应该首先提供一个示例 pytree,让 Orbax 或 Flax 知道要恢复哪个结构。

本部分演示了如何显式设置任何自定义 Flax 数据类,使其与保存的检查点具有相同的结构。

注意:曾经是 JAX NumPy 数组 (jnp.array) 格式的数据将作为 NumPy 数组 (numpy.array) 恢复。这不会影响您的工作,因为 JAX 会 自动将 NumPy 数组转换为 JAX 数组,一旦计算开始。

empty_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=jax.tree_util.tree_map(np.zeros_like, variables['params']),  # values of the tree leaf doesn't matter
    tx=tx,
)
empty_config = {'dimensions': np.array([0, 0])}
target = {'model': empty_state, 'config': empty_config, 'data': [jnp.zeros_like(x1)]}
state_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save', item=target)
state_restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

使用旧 API#

或者,您可以从 Orbax CheckpointManager 和遗留的 Flax 代码中恢复,如下所示

checkpoint_manager.restore(4, items=target)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}
checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=target)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

通常建议将初始化检查点结构的过程(例如,TrainState)重构出来,以便保存/加载更容易且不易出错。这是因为像 apply_fntx(优化器)这样的函数和复杂对象无法序列化到检查点文件中,必须由代码初始化。

检查点结构不同时的恢复#

在开发过程中,更改模型、调整过程中添加/删除字段等操作会改变检查点结构。

本节介绍如何将旧数据加载到新代码中。

下面是一个简单的例子——一个从 flax.training.train_state.TrainState 扩展的 CustomTrainState,它包含一个名为 batch_stats 的额外字段。在处理实际模型时,在应用 批量归一化 时可能需要此字段。

这里,您将新的 CustomTrainState 存储为第 5 步,而第 4 步包含旧的/之前的 TrainState

class CustomTrainState(train_state.TrainState):
    batch_stats: Any = None

custom_state = CustomTrainState.create(
    apply_fn=state.apply_fn,
    params=state.params,
    tx=state.tx,
    batch_stats=np.arange(10),
)

custom_ckpt = {'model': custom_state, 'config': config, 'data': [x1]}
# Use a custom state to read the old `TrainState` checkpoint.
custom_target = {'model': custom_state, 'config': None, 'data': [jnp.zeros_like(x1)]}

# Save it in Orbax.
custom_save_args = orbax_utils.save_args_from_target(custom_ckpt)
checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args})
True

建议将您的检查点与您的 pytree 数据类定义保持同步。但是,您可能被迫在运行时恢复与不兼容的引用对象。发生这种情况时,检查点恢复将尝试在给定引用时尊重其结构。

以下是几种常见场景的示例。

场景 1:引用对象为部分对象#

如果您的引用对象是检查点的子树,则恢复将忽略额外的字段并恢复与引用具有相同结构的检查点。

就像下面的例子一样,CustomTrainState 中的 batch_stats 字段被忽略,检查点被恢复为 TrainState

这对于仅读取检查点的一部分也很有用。

restored = checkpoint_manager.restore(5, items=target)
assert not hasattr(restored, 'batch_stats')
assert type(restored['model']) == train_state.TrainState
restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=0, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

场景 2:检查点为部分对象#

另一方面,如果引用对象包含检查点中不可用的值,则检查点代码默认会警告某些数据不兼容。

要绕过错误,您需要传递一个 Orbax transform,它告诉 Orbax 如何将此检查点转换为 custom_target 的结构。

在这种情况下,传递一个默认的 {},让 Orbax 使用 custom_target 中的值来填充空白。这允许您将旧检查点恢复到新的数据结构 CustomTrainState 中。

try:
    checkpoint_manager.restore(4, items=custom_target)
except KeyError as e:
    print(f'KeyError when target state has an unmentioned field: {e}')
    print('')

# Step 4 is an original `TrainState`, without the `batch_stats`
custom_restore_args = orbax_utils.restore_args_from_target(custom_target)
restored = checkpoint_manager.restore(4, items=custom_target,
                                      restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})
assert type(restored['model']) == CustomTrainState
np.testing.assert_equal(restored['model'].batch_stats,
                        custom_target['model'].batch_stats)
restored
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
KeyError when target state has an unmentioned field: 'batch_stats'
{'config': None,
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)],
 'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}

使用 Orbax#

如果您已经使用 Orbax 后端保存了您的检查点,则可以使用 orbax_transforms 在 Flax API 中访问此 transforms 参数。

# Save in the "Flax-with-Orbax" backend.
flax.config.update('flax_use_orbax_checkpointing', True)
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=4,
                            overwrite=True,
                            keep=2)

checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=custom_target, step=4,
                               orbax_transforms={})
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),
 'config': None,
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

使用遗留 API#

使用遗留的 flax.training.checkpoints API,也可以做类似的事情,但它们不像 Orbax Transformations 那样灵活。

您需要使用 target=None 将检查点恢复为原始字典,相应地修改结构,然后将其反序列化回原始目标。

# Save using the legacy Flax `checkpoints` API.
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=5,
                            overwrite=True,
                            keep=2)

# Pass no target to get a raw state dictionary first.
raw_state_dict = checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=None, step=5)
# Add/remove fields as needed.
raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10))
# Restore the classes with correct target now
flax.serialization.from_state_dict(custom_target, raw_state_dict)
{'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),
 'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)]}

异步检查点#

检查点 I/O 繁重,如果您要保存大量数据,将它放到后台线程中可能会有所帮助,同时继续训练。

您可以通过创建一个 orbax.checkpoint.AsyncCheckpointer 来代替 orbax.checkpoint.PyTreeCheckpointer

注意:您应该使用同一个 async_checkpointer 来处理训练步骤中的所有异步保存,这样它就可以确保在下一个保存开始之前完成上一个异步保存。这使得簿记(如 keep(检查点的数量)和 overwrite)在各个步骤中保持一致。

每当您想要明确地等待异步保存完成时,可以调用 async_checkpointer.wait_until_finished()

# `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was
# originally designed for multi-process large model checkpointing.
# For Python notebooks or other single-process settings, just set up with `num_processes=1`.
# Refer to https://jax.ac.cn/en/latest/multi_process.html#initializing-the-cluster
# for how to set it up in multi-process scenarios.
jax.distributed.initialize("localhost:8889", num_processes=1, process_id=0)

async_checkpointer = orbax.checkpoint.AsyncCheckpointer(
    orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)

# Save your job:
async_checkpointer.save('/tmp/flax_ckpt/orbax/single_save_async', ckpt, save_args=save_args)
# ... Continue with your work...

# ... Until a time when you want to wait until the save completes:
async_checkpointer.wait_until_finished()  # Blocks until the checkpoint saving is completed.
async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

如果您使用的是 Orbax CheckpointManager,只需在初始化它时传入 async_checkpointer。然后,在实践中,调用 async_checkpoint_manager.wait_until_finished() 即可。

async_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed_async', async_checkpointer, options)
async_checkpoint_manager.wait_until_finished()

多主机/多进程检查点#

JAX 提供了几种方法可以在多个主机上同时扩展代码。当设备(CPU/GPU/TPU)数量非常大,以至于不同的设备由不同的主机(CPU)管理时,通常会发生这种情况。要开始使用 JAX 多进程设置,请查看 在多主机和多进程环境中使用 JAX分布式数组指南

在 JAX jit单程序多数据 (SPMD) 范式中,一个大型多进程数组可以将其数据跨不同的设备进行分片。(注意,JAX pjitjit 已合并到一个统一的接口中。要了解如何在多主机或多核环境中编译和执行 JAX 函数,请参考 本指南jax.Array 迁移指南。) 当多进程数组被序列化时,每个主机将其数据分片转储到一个共享存储中,例如 Google Cloud 存储桶。

Orbax 支持以与单进程 pytree 相同的方式保存和加载具有多进程数组的 pytree。但是,建议使用异步的 orbax.AsyncCheckpointer 在另一个线程中保存大型多进程数组,这样您就可以在保存的同时执行计算。使用纯 Orbax,在多进程上下文中保存检查点使用与单进程上下文中相同的 API。

from jax.sharding import PartitionSpec, NamedSharding

# Create an array sharded across multiple devices.
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, ('x', 'y'))

mp_array = jax.device_put(np.arange(8 * 2).reshape(8, 2),
                          NamedSharding(mesh, PartitionSpec('x', 'y')))

# Make it a pytree.
mp_ckpt = {'model': mp_array}
async_checkpoint_manager.save(0, mp_ckpt)
async_checkpoint_manager.wait_until_finished()

在使用多进程数组恢复检查点时,您需要指定每个数组应该恢复到哪个 sharding。否则,它们将被恢复为进程 0 上的大型 np.array,从而浪费时间和内存。

(在本笔记本中,由于我们使用的是单进程,因此即使我们提供分片,它也将被恢复为 np.array。)

使用 Orbax#

Orbax 允许您通过在 restore_args 中传递 sharding 的 pytree 来指定这一点。如果您已经拥有一个具有所有具有正确分片的数组的引用 pytree,则可以使用 orbax_utils.restore_args_from_target 将其转换为 Orbax 所需的 restore_args

# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
mp_smaller = jax.device_put(np.arange(8).reshape(4, 2),
                            NamedSharding(mesh, PartitionSpec('x', 'y')))
ref_ckpt = {'model': mp_smaller}

restore_args = orbax_utils.restore_args_from_target(ref_ckpt)
async_checkpoint_manager.restore(
    0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args})
{'model': Array([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]], dtype=int32)}

使用遗留的 Flax:使用 save_checkpoint_multiprocess#

在遗留的 Flax 中,要保存多进程数组,请使用 flax.training.checkpoints.save_checkpoint_multiprocess() 来代替 save_checkpoint(),并使用相同的参数。

如果您的检查点太大,您可以在管理器中指定 timeout_secs 并为其提供更多时间来完成写入。

async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)
checkpoints.save_checkpoint_multiprocess(ckpt_dir,
                                         mp_ckpt,
                                         step=3,
                                         overwrite=True,
                                         keep=4,
                                         orbax_checkpointer=async_checkpointer)
'/tmp/flax_ckpt/checkpoint_3'
mp_restored = checkpoints.restore_checkpoint(ckpt_dir,
                                             target=ref_ckpt,
                                             step=3,
                                             orbax_checkpointer=async_checkpointer)
mp_restored
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': Array([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]], dtype=int32)}

Orbax 作为后端故障排除#

作为从遗留的 Flax checkpoints API 迁移到 Orbax 的一个中间阶段,flax.training.checkpoints API 将从 2023 年 5 月 15 日开始使用 Orbax 作为其后端来保存检查点。

使用 Orbax 后端保存的检查点可以通过 flax.training.checkpoints.restore_checkpointorbax.checkpoint.PyTreeCheckpointer 读取。

从代码的角度来看,这相当于将配置标志 flax.config.flax_use_orbax_checkpointing 的默认值设置为 True。您可以在项目中随时使用 flax.config.update('flax_use_orbax_checkpointing', <BoolValue>) 覆盖此值。

总的来说,这种自动迁移不会影响大多数用户。但是,如果您的 API 使用方式遵循某些特定的模式,您可能会遇到问题。请查看下面的部分以进行故障排除。

如果您的设备在写入检查点时挂起#

如果您在多主机环境中运行(通常大于 8 个 TPU 设备),并且您的设备在写入检查点时挂起,请检查您的代码是否符合以下模式(即,save_checkpoint 仅在主机 0 上运行)

if jax.process_index() == 0:
  flax.training.checkpoints.save_checkpoint(...)

不幸的是,这是一种将被弃用的遗留模式,不再支持,因为在多进程环境中,检查点代码应该在主机之间进行协调,而不是仅在主机 0 上触发。将上面的代码替换为以下代码应该可以解决挂起问题

flax.training.checkpoints.save_checkpoint_multiprocess(...)

如果您没有保存 pytree#

Orbax 使用 orbax.checkpoint.PyTreeCheckpointHandler 来保存检查点,这意味着它们只保存 pytree。

如果您要保存单个数组或数字,则有两种选择

  1. 使用 orbax.ArrayCheckpointHandler 按照 本迁移部分 中的说明保存它们。

  2. 将其包装在一个 pytree 中并像往常一样保存。