flax.jax_utils 包#
我们可能考虑上游到 Jax 的实用程序。
- flax.jax_utils.partial_eval_by_shape(fn, input_spec, *args, **kwargs)[source]#
通过使用输入的形状来延迟评估函数。
此函数类似于
jax.eval_shape
,主要区别在于可以无需输入的具体值即可计算的函数输出将按原样返回,而不是仅返回形状。例如,请参阅module.init_by_shape
,其中此功能用于在不使用输入数据 lr 计算的情况下初始化模型。- 参数
fn – 要延迟评估的函数。
input_spec – 形状或 (形状,数据类型) 元组的可迭代对象,指定输入的形状和类型。如果未指定,则数据类型为 float32。
*args – 传递给模块的 apply 函数的其他参数
**kwargs – 传递给模块的 apply 函数的关键字参数
- 返回
包含模型输出和 Model 实例的一对
多设备实用程序#
- flax.jax_utils.replicate(tree, devices=None)[source]#
将数组复制到多个设备。
- 参数
tree – 包含应复制的数组的 pytree。
devices – 数据复制到的设备(默认值:与
jax.pmap()
预期的顺序相同)。
- 返回
包含复制数组的新 pytree。
- flax.jax_utils.prefetch_to_device(iterator, size, devices=None)[source]#
在设备上分片和预取批次。
此实用程序采用一个迭代器并返回一个新迭代器,该迭代器填充设备上的预取缓冲区。提前预取可以通过重叠计算和数据传输来显著提高训练循环的性能。
此实用程序主要适用于 GPU,对于 TPU 和 CPU,则无需使用 - TPU 和 CPU 内存分配器(通常)不会选择尚未空闲的内存位置,因此不会阻塞。相反,这些分配器会发生 OOM。
- 参数
iterator – 一个迭代器,它生成一个 ndarray 的 pytree,其中第一个维度在设备之间分片。
size –
预取缓冲区的大小。
如果您在 GPU 上进行训练,则 2 通常是最佳选择,因为这保证您可以将 GPU 上的训练步骤与 CPU 上的数据预取步骤重叠。
devices –
应将数组预取到的设备列表。
默认为
jax.pmap
预期的设备顺序。
- 产量
迭代器中的原始项,其中每个 ndarray 现在都分片到指定的设备。
- flax.jax_utils.pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(), static_return=False)[source]#
用填充、分片、然后取消分片、取消填充的代码包装函数。
- 参数
wrapped – 要包装的函数。签名为
params, *args, *kwargs
。static_argnums – 参数在
wrapped
中的索引,这些参数不应该被填充和分片,而应该按原样转发。默认为 (0,),因为到目前为止,最常见的用例是首先传递params
。static_argnames –
wrapped
的 kwargs 的名称,这些 kwargs 不应该被填充和分片,而应该按原样转发。static_return – 是否不取消分片和取消填充返回值;静态返回值通常与计算指标的 eval 步骤一起使用
- 返回
一个新函数,在将参数传递给包装函数之前填充和分片其参数,然后取消分片和取消填充返回的 pytree。
这对于使用不能被设备数量整除的输入调用 pmap’ed 函数很有用。一个典型的用法是
@pad_shard_unpad @jax.pmap def forward(params, x): …
笔记
填充在主机内存中完成,然后传递给函数,函数返回的值被传回主机内存。
返回的函数用一个新的仅关键字参数
min_device_batch
进行增强,如果指定了该参数,则强制将输入填充到每个设备至少此大小。这对于避免最后批次的重新编译并减少内存碎片很有用。有关更多信息,请参阅 https://flax.org.cn/en/latest/guides/data_preprocessing/full_eval.html