迁移检查点到 Orbax#
本指南展示了如何将 Flax 的检查点保存和恢复调用(flax.training.checkpoints.save_checkpoint 和 restore_checkpoint)转换为等效的 Orbax 方法。Orbax 为管理各种对象的检查点提供了灵活且可定制的 API。请注意,由于 Flax 的检查点正在从 flax.training.checkpoints
迁移到 Orbax,Flax API 中的所有现有功能将继续受到支持,但 API 将发生变化。
您将通过以下场景学习如何迁移到 Orbax
最常见的用例:保存/加载和管理检查点
一个“轻量级”用例:“纯”保存/加载,没有顶层检查点管理器
在没有目标 pytree 的情况下恢复检查点
异步检查点
保存/加载单个 JAX 或 NumPy 数组
要了解更多关于 Orbax 的信息,请查看 快速入门介绍性 Colab 笔记本 和 官方 Orbax 文档。
您可以点击上面的“在 Colab 中打开”运行本指南中的代码。
在整个指南中,您将能够比较使用和不使用 Orbax 代码的代码示例。
设置#
# Create some dummy variables for this example.
MAX_STEPS = 5
CKPT_PYTREE = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]
TARGET_PYTREE = [0, {'bar': np.array((0))}, [0, 0, 0]]
最常见的用例:保存/加载和管理检查点#
本节介绍以下场景
您原始的 Flax
save_checkpoint()
或save_checkpoint_multiprocess()
调用包含以下参数:prefix
、keep
、keep_every_n_steps
;或者您想对您的检查点使用一些自动管理逻辑(例如,用于删除旧数据,根据指标/损失删除数据,等等)。
在这种情况下,您需要使用 orbax.CheckpointManager
。这不仅允许您保存和加载您的模型,而且还允许您自动管理您的检查点并删除过时的检查点。
要升级您的代码
在顶层创建并保留一个
orbax.CheckpointManager
实例,并使用orbax.CheckpointManagerOptions
进行自定义。在运行时,调用
orbax.CheckpointManager.save()
保存您的数据。然后,调用
orbax.CheckpointManager.restore()
恢复您的数据。而且,如果您的检查点包含一些多主机/多进程数组,请在恢复之前将正确的
mesh
传递给flax.training.orbax_utils.restore_args_from_target()
以生成正确的restore_args
。
例如
CKPT_DIR = '/tmp/orbax_upgrade/'
flax.config.update('flax_use_orbax_checkpointing', False)
# Inside your training loop
for step in range(MAX_STEPS):
# do training
checkpoints.save_checkpoint(CKPT_DIR, CKPT_PYTREE, step=step,
prefix='test_', keep=3, keep_every_n_steps=2)
checkpoints.restore_checkpoint(CKPT_DIR, target=TARGET_PYTREE, step=4, prefix='test_')
CKPT_DIR = '/tmp/orbax_upgrade/orbax'
# At the top level
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
create=True, max_to_keep=3, keep_period=2, step_prefix='test')
ckpt_mgr = orbax.checkpoint.CheckpointManager(
CKPT_DIR,
orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)
# Inside your training loop
for step in range(MAX_STEPS):
# do training
save_args = flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE)
ckpt_mgr.save(step, CKPT_PYTREE, save_kwargs={'save_args': save_args})
restore_args = flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None)
ckpt_mgr.restore(4, items=TARGET_PYTREE, restore_kwargs={'restore_args': restore_args})
一个“轻量级”用例:“纯”保存/加载,没有顶层检查点管理器#
如果您不想维护一个顶层检查点管理器,您仍然可以使用 orbax.checkpoint.Checkpointer
保存和恢复任何单个检查点。请注意,这意味着您无法使用所有 Orbax 管理功能。
要迁移到 Orbax 代码,不要在 flax.save_checkpoint()
中使用 overwrite
参数,而是在 orbax.checkpoint.Checkpointer.save()
中使用 force
参数。
例如
PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(PURE_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True)
checkpoints.restore_checkpoint(PURE_CKPT_DIR, target=TARGET_PYTREE)
PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # A stateless object, can be created on the fly.
ckptr.save(PURE_CKPT_DIR, CKPT_PYTREE,
save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=True)
ckptr.restore(PURE_CKPT_DIR, item=TARGET_PYTREE,
restore_args=flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None))
在没有目标 pytree 的情况下恢复检查点#
如果您需要在没有目标 pytree 的情况下恢复检查点,请将 item=None
传递给 orbax.checkpoint.Checkpointer
或将 items=None
传递给 orbax.CheckpointManager
的 .restore()
方法,这应该会触发恢复。
例如
NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(NOTARGET_CKPT_DIR, CKPT_PYTREE, step=0)
checkpoints.restore_checkpoint(NOTARGET_CKPT_DIR, target=None)
NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target'
# A stateless object, can be created on the fly.
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
ckptr.save(NOTARGET_CKPT_DIR, CKPT_PYTREE,
save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE))
ckptr.restore(NOTARGET_CKPT_DIR, item=None)
异步检查点#
要使您的检查点保存异步,请用 orbax.checkpoint.AsyncCheckpointer
替换 orbax.checkpoint.Checkpointer
。
然后,您可以调用 orbax.checkpoint.AsyncCheckpointer.wait_until_finished()
或 Orbax 的 CheckpointerManager.wait_until_finished()
来等待保存完成。
有关更多详细信息,请阅读 检查点指南。
您也可以通过异步管理器使用 Orbax AsyncCheckpointer 与 Flax API。异步管理器在内部调用 wait_until_finished()。此解决方案未积极维护,建议使用 Orbax 异步检查点。
例如
ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
flax.config.update('flax_use_orbax_checkpointing', True)
async_manager = checkpoints.AsyncManager()
checkpoints.save_checkpoint(ASYNC_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager)
checkpoints.restore_checkpoint(ASYNC_CKPT_DIR, target=TARGET_PYTREE)
ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
import orbax.checkpoint as ocp
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(ASYNC_CKPT_DIR, args=ocp.args.StandardSave(CKPT_PYTREE))
# ... Continue with your work...
# ... Until a time when you want to wait until the save completes:
ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed.
ckptr.restore(ASYNC_CKPT_DIR, args=ocp.args.StandardRestore(TARGET_PYTREE))
保存/加载单个 JAX 或 NumPy 数组#
顾名思义,orbax.checkpoint.PyTreeCheckpointHandler
类只能用于 pytree。因此,如果您需要保存/恢复单个 pytree 叶子(例如,数组),请使用 orbax.checkpoint.ArrayCheckpointHandler
。
例如
ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(ARR_CKPT_DIR, jnp.arange(10), step=0)
checkpoints.restore_checkpoint(ARR_CKPT_DIR, target=None)
ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton'
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.ArrayCheckpointHandler())
ckptr.save(ARR_CKPT_DIR, jnp.arange(10))
ckptr.restore(ARR_CKPT_DIR, item=None)
最后的话#
本指南概述了如何从“传统”Flax 检查点 API 迁移到 Orbax API。Orbax 提供了更多功能,Orbax 团队正在积极开发新功能。敬请关注,并关注 官方 Orbax GitHub 存储库 获取更多信息!