常见问题解答 (FAQ)#

这是一些常见问题解答 (FAQ) 的答案集合。您可以通过在 GitHub Discussions 中发起新主题来为 Flax FAQ 贡献内容。

如何对中间值求导(使用 Module.perturb)?#

要对模型层内隐藏/中间激活的输出求导数/梯度,可以使用 flax.linen.Module.perturb()。您在正向传递中定义一个与中间激活形状相同的零值 flax.linen.Module“扰动”参数 - perturb(...) -,使用 'perturbations' 作为附加的独立参数定义损失函数,使用 jax.grad 对扰动参数执行 JAX 导数运算。

有关完整示例和详细文档,请访问

Flax Linen 的 remat_scan()scan(remat(...)) 相同吗?#

Flax 的 remat_scan() (flax.linen.remat_scan()) 和 scan(remat(...)) (flax.linen.scan() 应用于 flax.linen.remat()) 并不相同,并且 remat_scan() 在其支持的用例中受到限制。具体来说,remat_scan() 将输入和输出视为传递(贯穿训练循环的隐藏状态)。建议您使用 scan(remat(...)),因为通常您需要额外的参数,例如 in_axes(输入数组轴)或 out_axes(输出数组轴),而 flax.linen.remat_scan() 未公开这些参数。