Save and load checkpoints#

This guide demonstrates how to save and load Flax NNX model checkpoints with Orbax.

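Note: The Flax team does not actively maintain a library for saving and loading model checkpoints to disk. We therefore recommend using an external library such as Orbax for this.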

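In this guide you will learn how to:

  • Save checkpoints.

  • Restore checkpoints.

  • Restore checkpoints when the checkpoint structure differs from the current model.

  • Perform multi-process checkpointing.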

The Orbax API examples in this guide are for demonstration purposes only. For the most up-to-date recommended APIs, refer to the Orbax website.


Note: If you are looking for the legacy flax.training.checkpoints package for Flax Linen, it was deprecated in 2023 in favor of Orbax. Its documentation is available on the Flax Linen site.

Setup#

Import the necessary dependencies, set up a checkpoint directory, and define an example Flax NNX model, TwoLayerMLP, by subclassing nnx.Module.

from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')

class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

# Instantiate the model and show we can run it.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

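Save checkpoints#

JAX checkpointing libraries, such as Orbax, can save and load any given JAX pytree, which is a pure, possibly nested container of jax.Arrays (or "tensors", as some other frameworks call them). In machine learning, a checkpoint is usually a pytree of model parameters and other data, such as optimizer state.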
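To make the pytree idea concrete, here is a minimal sketch that uses only JAX. The names `params`, `opt_state`, and `checkpoint_tree` are illustrative assumptions and are not part of the Flax or Orbax APIs.

```python
import jax
import jax.numpy as jnp

# A hypothetical checkpoint: nested dicts whose leaves are jax.Arrays.
params = {
    'linear1': {'kernel': jnp.ones((4, 4))},
    'linear2': {'kernel': jnp.ones((4, 4))},
}
opt_state = {
    'step': jnp.array(0),
    # One momentum buffer per parameter, with the same nested structure.
    'momentum': jax.tree_util.tree_map(jnp.zeros_like, params),
}
checkpoint_tree = {'params': params, 'opt_state': opt_state}

# Flatten the pytree into a list of leaves plus a description of its structure.
leaves, treedef = jax.tree_util.tree_flatten(checkpoint_tree)
num_leaves = len(leaves)
total_size = sum(leaf.size for leaf in leaves)

# Rebuild the same nested structure from the flat list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

A checkpointing library works on this kind of structure: it stores each leaf array together with enough information to rebuild the nesting.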

In Flax NNX, you can obtain such a pytree from an nnx.Module by calling nnx.split and taking the returned nnx.State.

_, state = nnx.split(model)
nnx.display(state)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)

Restore checkpoints#

Note that you saved the checkpoint as an nnx.State, a Flax class that also nests the nnx.VariableState and nnx.Param classes.

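At restore time, these classes must be available in your runtime, and you need to instruct the checkpointing library (Orbax) to restore your pytree into that structure. You can do this as follows:

  • First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.

  • Once you have the state, use nnx.merge to obtain your Flax NNX model, and use it as usual.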
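What "abstract" means in the first step can be sketched with plain `jax.eval_shape`, which traces a function and returns only shapes and dtypes. The `init_params` function below is a hypothetical stand-in for a model constructor, not a Flax API.

```python
import jax
import jax.numpy as jnp

def init_params(key, dim=4):
  """Hypothetical initializer standing in for a model constructor."""
  k1, k2 = jax.random.split(key)
  return {
      'linear1': {'kernel': jax.random.normal(k1, (dim, dim))},
      'linear2': {'kernel': jax.random.normal(k2, (dim, dim))},
  }

# Trace the initializer instead of running it: no parameter arrays are
# allocated, and every leaf comes back as a jax.ShapeDtypeStruct.
abstract_params = jax.eval_shape(lambda: init_params(jax.random.key(0)))
abstract_leaf = abstract_params['linear1']['kernel']
```

An abstract tree like this carries the structure, shapes, and dtypes that a checkpointing library needs as a restore target, without holding any real data.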

# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
print('The abstract NNX state (all leaves are abstract arrays):')
nnx.display(abstract_state)

state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
jax.tree.map(np.testing.assert_array_equal, state, state_restored)
print('NNX State restored: ')
nnx.display(state_restored)

# The model is now good to use!
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)
The abstract NNX state (all leaves are abstract arrays):
NNX State restored: 
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1136: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(

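Save and restore as pure dictionaries#

When interacting with checkpointing libraries such as Orbax, you may prefer to work with Python built-in container types. In that case, you can use the nnx.State.to_pure_dict and nnx.State.replace_by_pure_dict APIs to convert an nnx.State to and from a pure nested dictionary.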

# Save as pure dict
pure_dict_state = state.to_pure_dict()
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)

# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4)  # The model still works!
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.

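Restore when checkpoint structures differ#

The ability to load a checkpoint as a pure nested dictionary comes in handy when you want to load an outdated checkpoint that no longer matches your current model code. See the simple example below.

This pattern also works if you saved the checkpoint as an nnx.State rather than as a pure dictionary. See the Checkpoint surgery section of the Model surgery guide for an example with code. The only difference is that you need to reprocess your raw dictionary before calling nnx.State.replace_by_pure_dict.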
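Reprocessing a restored dictionary often amounts to filling in entries that the old checkpoint lacks. The sketch below shows one way to do that with plain Python dictionaries and NumPy arrays. The helper `fill_missing` and the example trees are hypothetical and are not part of Flax or Orbax.

```python
import numpy as np

def fill_missing(restored, defaults):
  """Return a copy of `restored` with entries missing relative to `defaults` filled in."""
  merged = dict(restored)
  for key, default in defaults.items():
    if isinstance(default, dict):
      # Recurse into sub-dictionaries, creating them if the checkpoint lacks them.
      merged[key] = fill_missing(restored.get(key, {}), default)
    elif key not in merged:
      # Leaf absent from the old checkpoint: fall back to the default value.
      merged[key] = default
  return merged

# An old checkpoint without bias arrays (illustrative values).
old_checkpoint = {
    'linear1': {'kernel': np.ones((4, 4))},
    'linear2': {'kernel': np.ones((4, 4))},
}
# Default values for every entry the new model expects.
new_defaults = {
    'linear1': {'kernel': np.zeros((4, 4)), 'bias': np.zeros((4,))},
    'linear2': {'kernel': np.zeros((4, 4)), 'bias': np.zeros((4,))},
}
patched = fill_missing(old_checkpoint, new_defaults)
```

Values from the old checkpoint take precedence, so only the missing bias arrays are filled from the defaults, and the input dictionary is left unmodified.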

class ModifiedTwoLayerMLP(nnx.Module):
  """A modified version of TwoLayerMLP, which requires bias arrays."""
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True)  # We need bias now!
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True)  # We need bias now!

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

# Accommodate your old checkpoint to the new code.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))
restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))

# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4)  # The new model works!

nnx.display(model.linear1)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.

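Multi-process checkpointing#

In a multi-host/multi-process environment, you will typically want to restore your checkpoint sharded across multiple devices. See the Load sharded model from a checkpoint section of the Flax Scale up on multiple devices guide to learn how to derive a sharding pytree and use it to load your checkpoint.

Note: JAX provides several ways to scale your code across multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). See JAX's Introduction to parallel programming, Using JAX in multi-host and multi-process environments, Distributed arrays and automatic parallelization, and Manual parallelism with shard_map.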
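As a rough illustration of a sharding pytree, the sketch below attaches a `NamedSharding` to every leaf of an abstract tree using only JAX. The mesh axis name `'data'`, the tree layout, and the choice to shard the first array dimension are assumptions made for this example. The linked Flax guide describes the recommended way to derive shardings for an NNX model.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over all available devices (a single device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('data',))

# Pick a dimension that divides evenly across the mesh.
dim = 4 * devices.size

# Abstract tree: shapes and dtypes only, no arrays allocated.
abstract = jax.eval_shape(
    lambda: {'linear1': {'kernel': jnp.zeros((dim, dim))}})

# Shard the first axis of every leaf across the 'data' mesh axis.
sharding = NamedSharding(mesh, P('data', None))
abstract_sharded = jax.tree_util.tree_map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=sharding),
    abstract)
sharded_leaf = abstract_sharded['linear1']['kernel']
```

An abstract tree annotated this way is the kind of restore target that lets a checkpointing library place each array on the right devices when loading.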

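Other checkpointing features#

This guide uses only the simplest orbax.checkpoint.StandardCheckpointer API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit.

In addition, see the Orbax website for other commonly used features, such as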