迁移检查点到 Orbax#

本指南展示了如何将 Flax 的检查点保存和恢复调用(flax.training.checkpoints.save_checkpointrestore_checkpoint)转换为等效的 Orbax 方法。Orbax 为管理各种对象的检查点提供了灵活且可定制的 API。请注意,由于 Flax 的检查点正在从 flax.training.checkpoints 迁移到 Orbax,Flax API 中的所有现有功能将继续受到支持,但 API 将发生变化。

您将通过以下场景学习如何迁移到 Orbax

  • 最常见的用例:保存/加载和管理检查点

  • 一个“轻量级”用例:“纯”保存/加载,没有顶层检查点管理器

  • 在没有目标 pytree 的情况下恢复检查点

  • 异步检查点

  • 保存/加载单个 JAX 或 NumPy 数组

要了解更多关于 Orbax 的信息,请查看 快速入门介绍性 Colab 笔记本官方 Orbax 文档

您可以点击上面的“在 Colab 中打开”运行本指南中的代码。

在整个指南中,您将能够比较使用和不使用 Orbax 代码的代码示例。

设置#

# Create some dummy variables for this example.
MAX_STEPS = 5
CKPT_PYTREE = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]
TARGET_PYTREE = [0, {'bar': np.array((0))}, [0, 0, 0]]

最常见的用例:保存/加载和管理检查点#

本节介绍以下场景

  • 您原始的 Flax save_checkpoint()save_checkpoint_multiprocess() 调用包含以下参数:prefixkeepkeep_every_n_steps;或者

  • 您想对您的检查点使用一些自动管理逻辑(例如,用于删除旧数据,根据指标/损失删除数据,等等)。

在这种情况下,您需要使用 orbax.CheckpointManager。这不仅允许您保存和加载您的模型,而且还允许您自动管理您的检查点并删除过时的检查点。

要升级您的代码

  1. 在顶层创建并保留一个 orbax.CheckpointManager 实例,并使用 orbax.CheckpointManagerOptions 进行自定义。

  2. 在运行时,调用 orbax.CheckpointManager.save() 保存您的数据。

  3. 然后,调用 orbax.CheckpointManager.restore() 恢复您的数据。

  4. 而且,如果您的检查点包含一些多主机/多进程数组,请在恢复之前将正确的 mesh 传递给 flax.training.orbax_utils.restore_args_from_target() 以生成正确的 restore_args

例如

CKPT_DIR = '/tmp/orbax_upgrade/'
flax.config.update('flax_use_orbax_checkpointing', False)

# Inside your training loop
for step in range(MAX_STEPS):
  # do training
  checkpoints.save_checkpoint(CKPT_DIR, CKPT_PYTREE, step=step,
                              prefix='test_', keep=3, keep_every_n_steps=2)


checkpoints.restore_checkpoint(CKPT_DIR, target=TARGET_PYTREE, step=4, prefix='test_')
CKPT_DIR = '/tmp/orbax_upgrade/orbax'

# At the top level
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
  create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
  CKPT_DIR,
  orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)

# Inside your training loop
for step in range(MAX_STEPS):
  # do training
  save_args = flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE)
  ckpt_mgr.save(step, CKPT_PYTREE, save_kwargs={'save_args': save_args})


restore_args = flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None)
ckpt_mgr.restore(4, items=TARGET_PYTREE, restore_kwargs={'restore_args': restore_args})

一个“轻量级”用例:“纯”保存/加载,没有顶层检查点管理器#

如果您不想维护一个顶层检查点管理器,您仍然可以使用 orbax.checkpoint.Checkpointer 保存和恢复任何单个检查点。请注意,这意味着您无法使用所有 Orbax 管理功能。

要迁移到 Orbax 代码,不要在 flax.save_checkpoint() 中使用 overwrite 参数,而是在 orbax.checkpoint.Checkpointer.save() 中使用 force 参数。

例如

PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
flax.config.update('flax_use_orbax_checkpointing', False)

checkpoints.save_checkpoint(PURE_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True)
checkpoints.restore_checkpoint(PURE_CKPT_DIR, target=TARGET_PYTREE)
PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'

ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())  # A stateless object, can be created on the fly.
ckptr.save(PURE_CKPT_DIR, CKPT_PYTREE,
           save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=True)
ckptr.restore(PURE_CKPT_DIR, item=TARGET_PYTREE,
              restore_args=flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None))

在没有目标 pytree 的情况下恢复检查点#

如果您需要在没有目标 pytree 的情况下恢复检查点,请将 item=None 传递给 orbax.checkpoint.Checkpointer 或将 items=None 传递给 orbax.CheckpointManager.restore() 方法,这应该会触发恢复。

例如

NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'
flax.config.update('flax_use_orbax_checkpointing', False)

checkpoints.save_checkpoint(NOTARGET_CKPT_DIR, CKPT_PYTREE, step=0)
checkpoints.restore_checkpoint(NOTARGET_CKPT_DIR, target=None)
NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'

# A stateless object, can be created on the fly.
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
ckptr.save(NOTARGET_CKPT_DIR, CKPT_PYTREE,
           save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE))
ckptr.restore(NOTARGET_CKPT_DIR, item=None)

异步检查点#

要使您的检查点保存异步,请用 orbax.checkpoint.AsyncCheckpointer 替换 orbax.checkpoint.Checkpointer

然后,您可以调用 orbax.checkpoint.AsyncCheckpointer.wait_until_finished() 或 Orbax 的 CheckpointerManager.wait_until_finished() 来等待保存完成。

有关更多详细信息,请阅读 检查点指南

您也可以通过异步管理器使用 Orbax AsyncCheckpointer 与 Flax API。异步管理器在内部调用 wait_until_finished()。此解决方案未积极维护,建议使用 Orbax 异步检查点。

例如

ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
flax.config.update('flax_use_orbax_checkpointing', True)
async_manager = checkpoints.AsyncManager()

checkpoints.save_checkpoint(ASYNC_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager)
checkpoints.restore_checkpoint(ASYNC_CKPT_DIR, target=TARGET_PYTREE)
ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'

import orbax.checkpoint as ocp
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(ASYNC_CKPT_DIR, args=ocp.args.StandardSave(CKPT_PYTREE))
# ... Continue with your work...
# ... Until a time when you want to wait until the save completes:
ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed.
ckptr.restore(ASYNC_CKPT_DIR, args=ocp.args.StandardRestore(TARGET_PYTREE))

保存/加载单个 JAX 或 NumPy 数组#

顾名思义,orbax.checkpoint.PyTreeCheckpointHandler 类只能用于 pytree。因此,如果您需要保存/恢复单个 pytree 叶子(例如,数组),请使用 orbax.checkpoint.ArrayCheckpointHandler

例如

ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'
flax.config.update('flax_use_orbax_checkpointing', False)

checkpoints.save_checkpoint(ARR_CKPT_DIR, jnp.arange(10), step=0)
checkpoints.restore_checkpoint(ARR_CKPT_DIR, target=None)
ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'

ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.ArrayCheckpointHandler())
ckptr.save(ARR_CKPT_DIR, jnp.arange(10))
ckptr.restore(ARR_CKPT_DIR, item=None)

最后的话#

本指南概述了如何从“传统”Flax 检查点 API 迁移到 Orbax API。Orbax 提供了更多功能,Orbax 团队正在积极开发新功能。敬请关注,并关注 官方 Orbax GitHub 存储库 获取更多信息!