常见问题解答 (FAQ)#
这是一些常见问题解答 (FAQ) 的答案集合。您可以通过在 GitHub Discussions 中发起新主题来为 Flax FAQ 贡献内容。
如何对中间值求导(使用 Module.perturb
)?#
要对模型层内隐藏/中间激活的输出求导数/梯度,可以使用 flax.linen.Module.perturb()
。您在正向传递中定义一个与中间激活形状相同的零值 flax.linen.Module
“扰动”参数 - perturb(...)
-,使用 'perturbations'
作为附加的独立参数定义损失函数,使用 jax.grad
对扰动参数执行 JAX 导数运算。
有关完整示例和详细文档,请访问
Flax Linen 的 remat_scan()
与 scan(remat(...))
相同吗?#
Flax 的 remat_scan()
(flax.linen.remat_scan()
) 和 scan(remat(...))
(flax.linen.scan()
应用于 flax.linen.remat()
) 并不相同,并且 remat_scan()
在其支持的用例中受到限制。具体来说,remat_scan()
将输入和输出视为传递(贯穿训练循环的隐藏状态)。建议您使用 scan(remat(...))
,因为通常您需要额外的参数,例如 in_axes
(输入数组轴)或 out_axes
(输出数组轴),而 flax.linen.remat_scan()
未公开这些参数。
推荐的训练循环库有哪些?#
考虑使用 CLU (Common Loop Utils) google/CommonLoopUtils。要开始使用,请访问此 CLU 概要 Colab。您可以在 google/flax GitHub Discussions 上找到有关 CLU 与 Flax 的常见问题的解答。
查看官方的 google/flax 示例,了解如何使用训练循环和 (CLU) 指标的示例。例如,这是 Flax ImageNet 的 train.py。
对于计算机视觉研究,请考虑使用 google-research/scenic。Scenic 是一组共享的轻量级库,用于解决训练大规模视觉模型时遇到的常见任务(并提供多个项目的示例)。Scenic 使用 JAX 和 Flax 开发。要开始使用,请访问 GitHub 上的 README 页面。