flax.training 包#

检查点#

检查点辅助函数。

处理根据步数或其他数值指标在文件名中保存和恢复优化器检查点。清理旧的/性能较差的检查点文件。

flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, orbax_checkpointer=None)[source]#

保存模型的检查点。适用于单主机。

在此方法中,每个 JAX 进程都会单独保存检查点。如果您有多个进程,并且您希望它们将数据保存到公共目录(例如 GCloud 存储桶),请不要使用它。要将多进程检查点保存到共享存储或保存 GlobalDeviceArray``s, 使用 ``save_checkpoint_multiprocess() 代替。

通过写入临时文件来确保抢占安全,然后进行最终重命名并清理过去的文件。但是,如果使用 async_manager,则最终提交将在异步回调中发生,可以通过调用 async_manager.wait_previous_save() 显式等待。

示例用法

>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile

>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}
参数
  • ckpt_dir – str 或 pathlib 类路径,用于存储检查点文件。

  • target – 可序列化的 flax 对象,通常是 flax 优化器。

  • step – int 或 float:训练步数或其他指标数字。

  • prefix – str:检查点文件名前缀。

  • keep – 要保留的过去检查点文件数量。

  • overwrite – 如果当前步数或更高步数的检查点已存在,则覆盖现有的检查点文件(默认值:False)。

  • keep_every_n_steps – 如果定义,则每隔 n 步保留一个检查点(除了保留最后 “keep” 个检查点)。

  • async_manager – 如果定义,则保存将运行而不会阻塞主线程。仅适用于单主机。请注意,正在进行的保存仍将阻塞后续保存,以确保覆盖/保留逻辑正常运行。

  • orbax_checkpointer – 如果定义,则保存将由 ocp 完成。将来,所有 Flax 检查点功能都将迁移到 Orbax,建议开始使用 orbax_checkpointer。请查看检查点指南 (https://flax.org.cn/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) 以了解如何使用 Orbax 检查点器。

返回值

已保存检查点的文件名。

flax.training.checkpoints.save_checkpoint_multiprocess(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None, orbax_checkpointer=None)[source]#

在多进程环境中保存模型的检查点。

使用此方法保存 ``GlobalDeviceArray``s,或将数据保存到公共目录。只有进程 0 会保存主要的检查点文件并删除旧的检查点文件。

通过写入临时文件来确保抢占安全,然后进行最终重命名并清理过去的文件。但是,如果使用 async_manager 或 gda_manager,则最终提交将在异步回调中发生,可以通过调用 async_manager.wait_previous_save()gda_manager.wait_until_finished() 显式等待。

参数
  • ckpt_dir – str 或 pathlib 类路径,用于存储检查点文件。

  • target – 可序列化的 flax 对象,通常是 flax 优化器。

  • step – int 或 float:训练步数或其他指标数字。

  • prefix – str:检查点文件名前缀。

  • keep – 要保留的过去检查点文件数量。

  • overwrite – 如果当前步数或更高步数的检查点已存在,则覆盖现有的检查点文件(默认值:False)。

  • keep_every_n_steps – 如果定义,则每隔 n 步保留一个检查点(除了保留最后 “keep” 个检查点)。

  • async_manager – 如果定义,则保存将运行而不会阻塞主线程。仅适用于单主机。请注意,正在进行的保存仍将阻塞后续保存,以确保覆盖/保留逻辑正常运行。

  • gda_manager – 如果目标包含 JAX GlobalDeviceArray,则需要。会将 GDA 异步保存到具有后缀 “_gda” 的单独子目录。与 async_manager 一样,这将阻塞后续保存。

  • orbax_checkpointer – 如果定义,则保存将由 Orbax 完成。将来,所有 Flax 检查点功能都将迁移到 Orbax,建议开始使用 orbax_checkpointer。请查看检查点指南 (https://flax.org.cn/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) 以了解如何使用 Orbax 检查点器。

返回值

已保存检查点的文件名。

flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')[source]#

检索目录中最新检查点的路径。

参数
  • ckpt_dir – str:要从中恢复的检查点目录。

  • prefix – str:检查点文件名的名称前缀。

返回值

最新检查点路径,如果没有找到检查点,则为 None。

flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None, allow_partial_mpa_restoration=False, orbax_checkpointer=None, orbax_transforms=None)[source]#

从路径中的检查点恢复最后一个/最佳检查点。

对检查点文件进行自然排序,返回值最大的文件,例如:

  • ckpt_1, ckpt_2, ckpt_3 --> ckpt_3

  • ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1

  • ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

示例用法

>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile
...
>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}
参数
  • ckpt_dir – str: 要从中恢复的检查点文件或检查点目录。

  • target – 通过反序列化状态字典重建的匹配对象。如果为 None,则反序列化状态字典将按原样返回。

  • step – int 或 float:要加载的步数,或 None 以加载最新步数。如果指定,则 ckpt_dir 必须是目录。

  • prefix – str:检查点文件名的名称前缀。

  • parallel – bool: 是否并行加载可寻求的检查点,以提高速度。

  • gda_manager – 检查点包含多进程数组(GlobalDeviceArray 或来自 pjit 的 jax 数组)时需要。将从带有后缀“_gda”的单独子目录中读取数组。

  • allow_partial_mpa_restoration – 如果为 True,则给定的 target 不必包含所有有效的多进程数组。因此,恢复的 Pytree 可能包含一些未正确恢复的 MPA。如果您无法提供完全有效的 target 并且不需要恢复检查点中的所有 MPA,请使用此选项。

  • orbax_checkpointer – 如果给定检查点使用 ocp 保存,则处理底层恢复的 ocp.Checkpointer

  • orbax_transforms – 将传递到 orbax_checkpointer.restore() 调用的 Orbax 转换。

返回值

从检查点文件更新的已恢复 target,或者如果未指定步数且不存在检查点文件,则返回未更改的传入 target。如果指定了文件路径但未找到,则将返回传入的 target。这与指定目录路径但目录尚未创建的情况的行为相匹配。

flax.training.checkpoints.convert_pre_linen(params)[source]#

转换 pre-Linen 参数 pytree。

在 pre-Linen API 中,子模块是按顺序编号的,与子模块类无关。在 Linen 中,此行为已更改,以在每个模块类中保持单独的子模块计数。

考虑以下模块

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(1, 1)(x)
    x = nn.Dense(1)(x)
    return x

在 pre-Linen 中,生成的 params 将具有以下结构

{'Conv_0': { ... }, 'Dense_1': { ... } }

在 Linen 中,生成的 params 将具有以下结构

{'Conv_0': { ... }, 'Dense_0': { ... } }

要从 pre-Linen 格式转换为 Linen,只需调用

params = convert_pre_linen(pre_linen_params)

请注意,您也可以使用此实用程序来转换 pre-Linen 集合,因为它们遵循相同的模块命名。但是请注意,集合在 pre-Linen 中是“扁平的”,需要先展开才能使用此函数

batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
    tuple(k.split('/')[1:]): v
    for k, v in pre_linen_model_state.as_dict().items()
}))

然后可以使用这些转换后的集合定义 Linen 变量

variables = {'params': params, 'batch_stats': batch_stats}
参数

params – pre-Linen 格式的参数 pytree。如果 pytree 已经是 Linen 格式,则返回的 pytree 将保持不变(即,此函数可以安全地调用任何加载的检查点以用于 Linen)。

返回值

具有 Linen 子模块命名的参数 pytree。

学习率计划#

FLAX 图像分类示例中使用的学习率计划。

请注意,随着 FLIP #1009flax.training 中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划

flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)[source]#

创建具有可选预热的常数学习率计划。

请注意,随着 FLIP #1009flax.training 中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划

保持学习率恒定。此函数还提供学习率预热,如 https://arxiv.org/abs/1706.02677 所述,用于使用大批量进行训练。

参数
  • base_learning_rate – 基本学习率

  • steps_per_epoch – 每个 epoch 的迭代次数

  • warmup_length – 如果 > 0,则学习率将通过预热因子进行调制,该因子将在前 warmup_length 个 epoch 中从 0 线性递增到 1

返回值

函数 f(step) -> lr,用于计算给定步数的学习率。

flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)[source]#

创建具有可选预热的阶梯式学习率计划。

请注意,随着 FLIP #1009flax.training 中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划

阶梯式学习率计划在指定的 epoch 中以指定的量减少学习率。这些步骤在 lr_sched_steps 参数中给出。一个常见的 ImageNet 计划在 epoch 30、60 和 80 处将学习率降低 0.1 倍。这将指定为

[
  [30, 0.1],
  [60, 0.01],
  [80, 0.001]
]

此函数还提供学习率预热,如 https://arxiv.org/abs/1706.02677 所述,用于使用大批量进行训练。

参数
  • base_learning_rate – 基本学习率

  • steps_per_epoch – 每个 epoch 的迭代次数

  • lr_sched_steps – 计划,以步骤列表的形式给出,每个步骤都是一个 [epoch, lr_factor] 对;该步骤在 epoch epoch 处发生,并将学习率设置为 base_learning_rage * lr_factor

  • warmup_length – 如果 > 0,则学习率将通过预热因子进行调制,该因子将在前 warmup_length 个 epoch 中从 0 线性递增到 1

返回值

函数 f(step) -> lr,用于计算给定步数的学习率。

flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)[source]#

创建具有可选预热的余弦学习率计划。

请注意,随着 FLIP #1009flax.training 中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划

余弦学习率计划使用半个余弦波来调制学习率,在训练结束时逐渐将其缩放到 0。

此函数还提供学习率预热,如 https://arxiv.org/abs/1706.02677 所述,用于使用大批量进行训练。

参数
  • base_learning_rate – 基本学习率

  • steps_per_epoch – 每个 epoch 的迭代次数

  • halfcos_epochs – 完成半个余弦波的 epoch 数;通常是用于训练的 epoch 数

  • warmup_length – 如果 > 0,则学习率将通过预热因子进行调制,该因子将在前 warmup_length 个 epoch 中从 0 线性递增到 1

返回值

函数 f(step) -> lr,用于计算给定步数的学习率。

训练状态#

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#

适用于使用单个 Optax 优化器的常见情况的简单训练状态。

示例用法

>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(3.3514676, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(3.343844, dtype=float32)

请注意,您可以通过子类化此数据类来轻松扩展它,以存储其他数据(例如,其他变量集合)。

对于更奇特的用例(例如,多个优化器),最好克隆该类并对其进行修改。

参数
  • step – 计数器从 0 开始,每次调用 .apply_gradients() 时都会递增。

  • apply_fn – 通常设置为 model.apply()。保留在此数据类中,以便为训练循环中的 train_step() 函数提供更短的参数列表。

  • params – 由 tx 更新并由 apply_fn 使用的参数。

  • tx – Optax 梯度转换。

  • opt_statetx 的状态。

apply_gradients(*, grads, **kwargs)[source]#

更新返回值中的 stepparamsopt_state**kwargs

请注意,在内部,此函数会调用 .tx.update(),然后调用 optax.apply_updates() 来更新 paramsopt_state

参数
  • grads – 与 .params 具有相同 pytree 结构的梯度。

  • **kwargs – 应该被 .replace() 的其他数据类属性。

返回值

一个更新后的 self 实例,其中 step 增加 1,paramsopt_state 通过应用 grads 更新,并且其他属性根据 kwargs 中指定的替换。

classmethod create(*, apply_fn, params, tx, **kwargs)[source]#

创建一个新的实例,其中 step=0 并且 opt_state 初始化。

提前停止#

class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False, has_improved=False)[source]#

在训练期间提前停止以避免过度拟合。

以下示例如果当前 epoch 和前一个 epoch 记录的损失之间的差小于 1e-3 连续 2 次,则提前停止训练

>>> from flax.training.early_stopping import EarlyStopping

>>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng):
...   ...
...   loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch]
...   return None, {'loss': loss}

>>> early_stop = EarlyStopping(min_delta=1e-3, patience=2)
>>> optimizer = None
>>> for epoch in range(10):
...   optimizer, train_metrics = train_epoch(
...       optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None)
...   early_stop = early_stop.update(train_metrics['loss'])
...   if early_stop.should_stop:
...     print(f'Met early stopping criteria, breaking at epoch {epoch}')
...     break
Met early stopping criteria, breaking at epoch 7
min_delta#

要被视为改进的更新之间的最小增量。

类型

float

patience#

停止之前没有改进的步骤数。

类型

int

best_metric#

当前最佳指标值。

类型

float

patience_count#

自上次改进更新以来的步骤数。

类型

int

should_stop#

训练循环是否应该停止以避免过度拟合。

类型

bool

has_improved#

指标是否在最后一次 .update 调用中比 min_delta 大或等于 min_delta

类型

bool

update(metric)[source]#

根据指标更新状态。

返回值

更新后的 EarlyStopping 类。当与前一个 best_metric 相比有大于 min_delta 的改进时,.has_improved 属性为 True。

通用工具#

flax.training.common_utils.shard(xs)[source]#

pmap 的辅助函数,用于通过 local_device_count 对数组的 pytree 进行分片。

参数

xs – 数组的 pytree。

返回值

与之匹配的 pytree,其中数组的领先维度通过本地设备数量进行分片。

flax.training.common_utils.shard_prng_key(prng_key)[source]#

用于对 PRNGKey 进行分片(也称为拆分)的辅助函数,以便与 pmap 的函数一起使用。

PRNG 密钥可以在训练时用于驱动随机模块,例如 Dropout。我们希望每个本地设备都有一个不同的 PRNG 密钥,以便我们在每个设备上获得不同的随机数,因此我们拆分了 PRNG 密钥。

参数

prng_key – JAX PRNGKey

返回值

一个新的 PRNGKeys 数组,其领先维度等于本地设备数量。

flax.training.common_utils.stack_forest(forest)[source]#

堆叠一系列 pytree 的叶子的辅助函数。

参数

forest – 匹配结构的一系列 pytree(例如元组或列表),其叶子是具有单独匹配形状的数组。

返回值

单个具有相同结构的 pytree,其叶子是单独

堆叠的数组。

flax.training.common_utils.get_metrics(device_metrics)[source]#

pmap 的辅助工具,用于收集复制的时间序列指标数据。

参数

device_metrics – 复制的、设备驻留的指标数据的 pytree,其叶子被假定为随着时间推移记录的数组序列。

返回值

一个未复制的、主机驻留的、随时间推移堆叠的数组的 pytree,可用于计算主机本地统计信息和日志记录。

flax.training.common_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)[source]#

创建索引数组的密集单热版本。

注意:考虑使用更标准的 jax.nn.one_hot 代替。

参数
  • labels – 一个 n 维 JAX 数组,其最后维度包含整数索引。

  • num_classes – 最大可能的索引。

  • on_value – 单热数组的“on”值,默认为 1.0。

  • off_value – 单热数组的“off”值,默认为 0.0。

返回值

一个 (n+1) 维数组,其最后维度包含长度为 num_classes 的单热向量。