flax.training 包#
检查点#
检查点辅助函数。
处理根据步数或其他数值指标在文件名中保存和恢复优化器检查点。清理旧的/性能较差的检查点文件。
- flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, orbax_checkpointer=None)[source]#
保存模型的检查点。适用于单主机。
在此方法中,每个 JAX 进程都会单独保存检查点。如果您有多个进程,并且您希望它们将数据保存到公共目录(例如 GCloud 存储桶),请不要使用它。要将多进程检查点保存到共享存储或保存
GlobalDeviceArray``s, 使用 ``save_checkpoint_multiprocess()
代替。通过写入临时文件来确保抢占安全,然后进行最终重命名并清理过去的文件。但是,如果使用 async_manager,则最终提交将在异步回调中发生,可以通过调用
async_manager.wait_previous_save()
显式等待。示例用法
>>> from flax.training import checkpoints >>> import jax.numpy as jnp >>> import tempfile >>> with tempfile.TemporaryDirectory() as dir_path: ... test_object = { ... 'a': jnp.array([1, 2, 3], jnp.int32), ... 'b': jnp.array([1, 1, 1], jnp.int32), ... } ... file_path = checkpoints.save_checkpoint( ... dir_path, target=test_object, step=0, prefix='test_', keep=1 ... ) ... restored_object = checkpoints.restore_checkpoint( ... file_path, target=None ... ) >>> restored_object {'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}
- 参数
ckpt_dir – str 或 pathlib 类路径,用于存储检查点文件。
target – 可序列化的 flax 对象,通常是 flax 优化器。
step – int 或 float:训练步数或其他指标数字。
prefix – str:检查点文件名前缀。
keep – 要保留的过去检查点文件数量。
overwrite – 如果当前步数或更高步数的检查点已存在,则覆盖现有的检查点文件(默认值:False)。
keep_every_n_steps – 如果定义,则每隔 n 步保留一个检查点(除了保留最后 “keep” 个检查点)。
async_manager – 如果定义,则保存将运行而不会阻塞主线程。仅适用于单主机。请注意,正在进行的保存仍将阻塞后续保存,以确保覆盖/保留逻辑正常运行。
orbax_checkpointer – 如果定义,则保存将由 ocp 完成。将来,所有 Flax 检查点功能都将迁移到 Orbax,建议开始使用
orbax_checkpointer
。请查看检查点指南 (https://flax.org.cn/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) 以了解如何使用 Orbax 检查点器。
- 返回值
已保存检查点的文件名。
- flax.training.checkpoints.save_checkpoint_multiprocess(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None, orbax_checkpointer=None)[source]#
在多进程环境中保存模型的检查点。
使用此方法保存 ``GlobalDeviceArray``s,或将数据保存到公共目录。只有进程 0 会保存主要的检查点文件并删除旧的检查点文件。
通过写入临时文件来确保抢占安全,然后进行最终重命名并清理过去的文件。但是,如果使用 async_manager 或 gda_manager,则最终提交将在异步回调中发生,可以通过调用
async_manager.wait_previous_save()
或gda_manager.wait_until_finished()
显式等待。- 参数
ckpt_dir – str 或 pathlib 类路径,用于存储检查点文件。
target – 可序列化的 flax 对象,通常是 flax 优化器。
step – int 或 float:训练步数或其他指标数字。
prefix – str:检查点文件名前缀。
keep – 要保留的过去检查点文件数量。
overwrite – 如果当前步数或更高步数的检查点已存在,则覆盖现有的检查点文件(默认值:False)。
keep_every_n_steps – 如果定义,则每隔 n 步保留一个检查点(除了保留最后 “keep” 个检查点)。
async_manager – 如果定义,则保存将运行而不会阻塞主线程。仅适用于单主机。请注意,正在进行的保存仍将阻塞后续保存,以确保覆盖/保留逻辑正常运行。
gda_manager – 如果目标包含 JAX GlobalDeviceArray,则需要。会将 GDA 异步保存到具有后缀 “_gda” 的单独子目录。与 async_manager 一样,这将阻塞后续保存。
orbax_checkpointer – 如果定义,则保存将由 Orbax 完成。将来,所有 Flax 检查点功能都将迁移到 Orbax,建议开始使用
orbax_checkpointer
。请查看检查点指南 (https://flax.org.cn/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) 以了解如何使用 Orbax 检查点器。
- 返回值
已保存检查点的文件名。
- flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')[source]#
检索目录中最新检查点的路径。
- 参数
ckpt_dir – str:要从中恢复的检查点目录。
prefix – str:检查点文件名的名称前缀。
- 返回值
最新检查点路径,如果没有找到检查点,则为 None。
- flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None, allow_partial_mpa_restoration=False, orbax_checkpointer=None, orbax_transforms=None)[source]#
从路径中的检查点恢复最后一个/最佳检查点。
对检查点文件进行自然排序,返回值最大的文件,例如:
ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5
示例用法
>>> from flax.training import checkpoints >>> import jax.numpy as jnp >>> import tempfile ... >>> with tempfile.TemporaryDirectory() as dir_path: ... test_object = { ... 'a': jnp.array([1, 2, 3], jnp.int32), ... 'b': jnp.array([1, 1, 1], jnp.int32), ... } ... file_path = checkpoints.save_checkpoint( ... dir_path, target=test_object, step=0, prefix='test_', keep=1 ... ) ... restored_object = checkpoints.restore_checkpoint( ... file_path, target=None ... ) >>> restored_object {'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}
- 参数
ckpt_dir – str: 要从中恢复的检查点文件或检查点目录。
target – 通过反序列化状态字典重建的匹配对象。如果为 None,则反序列化状态字典将按原样返回。
step – int 或 float:要加载的步数,或 None 以加载最新步数。如果指定,则 ckpt_dir 必须是目录。
prefix – str:检查点文件名的名称前缀。
parallel – bool: 是否并行加载可寻求的检查点,以提高速度。
gda_manager – 检查点包含多进程数组(GlobalDeviceArray 或来自 pjit 的 jax 数组)时需要。将从带有后缀“_gda”的单独子目录中读取数组。
allow_partial_mpa_restoration – 如果为 True,则给定的
target
不必包含所有有效的多进程数组。因此,恢复的 Pytree 可能包含一些未正确恢复的 MPA。如果您无法提供完全有效的target
并且不需要恢复检查点中的所有 MPA,请使用此选项。orbax_checkpointer – 如果给定检查点使用 ocp 保存,则处理底层恢复的
ocp.Checkpointer
。orbax_transforms – 将传递到
orbax_checkpointer.restore()
调用的 Orbax 转换。
- 返回值
从检查点文件更新的已恢复
target
,或者如果未指定步数且不存在检查点文件,则返回未更改的传入target
。如果指定了文件路径但未找到,则将返回传入的target
。这与指定目录路径但目录尚未创建的情况的行为相匹配。
- flax.training.checkpoints.convert_pre_linen(params)[source]#
转换 pre-Linen 参数 pytree。
在 pre-Linen API 中,子模块是按顺序编号的,与子模块类无关。在 Linen 中,此行为已更改,以在每个模块类中保持单独的子模块计数。
考虑以下模块
class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(1, 1)(x) x = nn.Dense(1)(x) return x
在 pre-Linen 中,生成的 params 将具有以下结构
{'Conv_0': { ... }, 'Dense_1': { ... } }
在 Linen 中,生成的 params 将具有以下结构
{'Conv_0': { ... }, 'Dense_0': { ... } }
要从 pre-Linen 格式转换为 Linen,只需调用
params = convert_pre_linen(pre_linen_params)
请注意,您也可以使用此实用程序来转换 pre-Linen 集合,因为它们遵循相同的模块命名。但是请注意,集合在 pre-Linen 中是“扁平的”,需要先展开才能使用此函数
batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({ tuple(k.split('/')[1:]): v for k, v in pre_linen_model_state.as_dict().items() }))
然后可以使用这些转换后的集合定义 Linen 变量
variables = {'params': params, 'batch_stats': batch_stats}
- 参数
params – pre-Linen 格式的参数 pytree。如果 pytree 已经是 Linen 格式,则返回的 pytree 将保持不变(即,此函数可以安全地调用任何加载的检查点以用于 Linen)。
- 返回值
具有 Linen 子模块命名的参数 pytree。
学习率计划#
FLAX 图像分类示例中使用的学习率计划。
请注意,随着 FLIP #1009,flax.training
中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划。
- flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)[source]#
创建具有可选预热的常数学习率计划。
请注意,随着 FLIP #1009,
flax.training
中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划。保持学习率恒定。此函数还提供学习率预热,如 https://arxiv.org/abs/1706.02677 所述,用于使用大批量进行训练。
- 参数
base_learning_rate – 基本学习率
steps_per_epoch – 每个 epoch 的迭代次数
warmup_length – 如果 > 0,则学习率将通过预热因子进行调制,该因子将在前
warmup_length
个 epoch 中从 0 线性递增到 1
- 返回值
函数
f(step) -> lr
,用于计算给定步数的学习率。
- flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)[source]#
创建具有可选预热的阶梯式学习率计划。
请注意,随着 FLIP #1009,
flax.training
中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划。阶梯式学习率计划在指定的 epoch 中以指定的量减少学习率。这些步骤在
lr_sched_steps
参数中给出。一个常见的 ImageNet 计划在 epoch 30、60 和 80 处将学习率降低 0.1 倍。这将指定为[ [30, 0.1], [60, 0.01], [80, 0.001] ]
此函数还提供学习率预热,如 https://arxiv.org/abs/1706.02677 所述,用于使用大批量进行训练。
- 参数
base_learning_rate – 基本学习率
steps_per_epoch – 每个 epoch 的迭代次数
lr_sched_steps – 计划,以步骤列表的形式给出,每个步骤都是一个
[epoch, lr_factor]
对;该步骤在 epochepoch
处发生,并将学习率设置为base_learning_rage * lr_factor
warmup_length – 如果 > 0,则学习率将通过预热因子进行调制,该因子将在前
warmup_length
个 epoch 中从 0 线性递增到 1
- 返回值
函数
f(step) -> lr
,用于计算给定步数的学习率。
- flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)[source]#
创建具有可选预热的余弦学习率计划。
请注意,随着 FLIP #1009,
flax.training
中的学习率计划实际上已弃用,转而支持 Optax 计划。有关更多信息,请参考 优化器计划。余弦学习率计划使用半个余弦波来调制学习率,在训练结束时逐渐将其缩放到 0。
此函数还提供学习率预热,如 https://arxiv.org/abs/1706.02677 所述,用于使用大批量进行训练。
- 参数
base_learning_rate – 基本学习率
steps_per_epoch – 每个 epoch 的迭代次数
halfcos_epochs – 完成半个余弦波的 epoch 数;通常是用于训练的 epoch 数
warmup_length – 如果 > 0,则学习率将通过预热因子进行调制,该因子将在前
warmup_length
个 epoch 中从 0 线性递增到 1
- 返回值
函数
f(step) -> lr
,用于计算给定步数的学习率。
训练状态#
- class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#
适用于使用单个 Optax 优化器的常见情况的简单训练状态。
示例用法
>>> import flax.linen as nn >>> from flax.training.train_state import TrainState >>> import jax, jax.numpy as jnp >>> import optax >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 2)) >>> model = nn.Dense(2) >>> variables = model.init(jax.random.key(0), x) >>> tx = optax.adam(1e-3) >>> state = TrainState.create( ... apply_fn=model.apply, ... params=variables['params'], ... tx=tx) >>> def loss_fn(params, x, y): ... predictions = state.apply_fn({'params': params}, x) ... loss = optax.l2_loss(predictions=predictions, targets=y).mean() ... return loss >>> loss_fn(state.params, x, y) Array(3.3514676, dtype=float32) >>> grads = jax.grad(loss_fn)(state.params, x, y) >>> state = state.apply_gradients(grads=grads) >>> loss_fn(state.params, x, y) Array(3.343844, dtype=float32)
请注意,您可以通过子类化此数据类来轻松扩展它,以存储其他数据(例如,其他变量集合)。
对于更奇特的用例(例如,多个优化器),最好克隆该类并对其进行修改。
- 参数
step – 计数器从 0 开始,每次调用
.apply_gradients()
时都会递增。apply_fn – 通常设置为
model.apply()
。保留在此数据类中,以便为训练循环中的train_step()
函数提供更短的参数列表。params – 由
tx
更新并由apply_fn
使用的参数。tx – Optax 梯度转换。
opt_state –
tx
的状态。
- apply_gradients(*, grads, **kwargs)[source]#
更新返回值中的
step
、params
、opt_state
和**kwargs
。请注意,在内部,此函数会调用
.tx.update()
,然后调用optax.apply_updates()
来更新params
和opt_state
。- 参数
grads – 与
.params
具有相同 pytree 结构的梯度。**kwargs – 应该被
.replace()
的其他数据类属性。
- 返回值
一个更新后的
self
实例,其中step
增加 1,params
和opt_state
通过应用grads
更新,并且其他属性根据kwargs
中指定的替换。
提前停止#
- class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False, has_improved=False)[source]#
在训练期间提前停止以避免过度拟合。
以下示例如果当前 epoch 和前一个 epoch 记录的损失之间的差小于 1e-3 连续 2 次,则提前停止训练
>>> from flax.training.early_stopping import EarlyStopping >>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng): ... ... ... loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch] ... return None, {'loss': loss} >>> early_stop = EarlyStopping(min_delta=1e-3, patience=2) >>> optimizer = None >>> for epoch in range(10): ... optimizer, train_metrics = train_epoch( ... optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None) ... early_stop = early_stop.update(train_metrics['loss']) ... if early_stop.should_stop: ... print(f'Met early stopping criteria, breaking at epoch {epoch}') ... break Met early stopping criteria, breaking at epoch 7
- min_delta#
要被视为改进的更新之间的最小增量。
- 类型
float
- patience#
停止之前没有改进的步骤数。
- 类型
int
- best_metric#
当前最佳指标值。
- 类型
float
- patience_count#
自上次改进更新以来的步骤数。
- 类型
int
- should_stop#
训练循环是否应该停止以避免过度拟合。
- 类型
bool
- has_improved#
指标是否在最后一次
.update
调用中比min_delta
大或等于min_delta
。- 类型
bool
通用工具#
- flax.training.common_utils.shard(xs)[source]#
pmap 的辅助函数,用于通过 local_device_count 对数组的 pytree 进行分片。
- 参数
xs – 数组的 pytree。
- 返回值
与之匹配的 pytree,其中数组的领先维度通过本地设备数量进行分片。
- flax.training.common_utils.shard_prng_key(prng_key)[source]#
用于对 PRNGKey 进行分片(也称为拆分)的辅助函数,以便与 pmap 的函数一起使用。
PRNG 密钥可以在训练时用于驱动随机模块,例如 Dropout。我们希望每个本地设备都有一个不同的 PRNG 密钥,以便我们在每个设备上获得不同的随机数,因此我们拆分了 PRNG 密钥。
- 参数
prng_key – JAX PRNGKey
- 返回值
一个新的 PRNGKeys 数组,其领先维度等于本地设备数量。
- flax.training.common_utils.stack_forest(forest)[source]#
堆叠一系列 pytree 的叶子的辅助函数。
- 参数
forest – 匹配结构的一系列 pytree(例如元组或列表),其叶子是具有单独匹配形状的数组。
- 返回值
- 单个具有相同结构的 pytree,其叶子是单独
堆叠的数组。
- flax.training.common_utils.get_metrics(device_metrics)[source]#
pmap 的辅助工具,用于收集复制的时间序列指标数据。
- 参数
device_metrics – 复制的、设备驻留的指标数据的 pytree,其叶子被假定为随着时间推移记录的数组序列。
- 返回值
一个未复制的、主机驻留的、随时间推移堆叠的数组的 pytree,可用于计算主机本地统计信息和日志记录。
- flax.training.common_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)[source]#
创建索引数组的密集单热版本。
注意:考虑使用更标准的
jax.nn.one_hot
代替。- 参数
labels – 一个 n 维 JAX 数组,其最后维度包含整数索引。
num_classes – 最大可能的索引。
on_value – 单热数组的“on”值,默认为 1.0。
off_value – 单热数组的“off”值,默认为 0.0。
- 返回值
一个 (n+1) 维数组,其最后维度包含长度为 num_classes 的单热向量。