处理 Flax 模块参数

处理 Flax 模块参数#

介绍#

在 Flax Linen 中,我们可以将 Module 参数定义为数据类属性或作为方法参数(通常是 __call__)。通常,区别很明显

  • 完全固定的属性,例如内核初始化器的选择或输出特征的数量,是超参数,应定义为数据类属性。通常,两个具有不同超参数的模块实例不能以有意义的方式共享。

  • 动态属性,例如输入数据和顶级“模式开关”如 train=True/False,应作为参数传递给 __call__ 或其他方法。

然而,有些情况并不那么清晰。例如,Dropout 模块。我们有一些明确的超参数

  1. 丢弃率

  2. 为其生成丢弃掩码的轴

以及一些明确的调用时参数

  1. 应使用丢弃进行掩码的输入

  2. 用于采样随机掩码的(可选)rng

但是,有一个属性很模糊——Dropout 模块中的 deterministic 属性。

如果 deterministicTrue,则不会采样丢弃掩码。这通常在模型评估期间使用。但是,如果我们将 eval=Truetrain=False 传递给顶级模块。则需要将 deterministic 参数应用到所有地方,并且需要将布尔值参数传递给所有可能使用 Dropout 的层。相反,如果 deterministic 是一个数据类属性,我们可能会执行以下操作

from functools import partial
from flax import linen as nn

class ResidualModel(nn.Module):
  drop_rate: float

  @nn.compact
  def __call__(self, x, *, train):
    dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
    for i in range(10):
      x += ResidualBlock(dropout=dropout, ...)(x)

determinstic 传递给构造函数是有意义的,因为这样我们就可以将丢弃模板传递给子模块。现在,子模块不再需要处理训练与评估模式,并且可以直接使用 dropout 参数。请注意,由于丢弃层只能在子模块中构建,因此我们只能将 deterministic 部分应用于构造函数,而不能应用于 __call__

但是,如果 deterministic 是一个数据类属性,我们在使用设置模式时会遇到麻烦。我们 **希望** 将模块代码写成这样

class SomeModule(nn.Module):
  drop_rate: float

  def setup(self):
    self.dropout = nn.Dropout(rate=self.drop_rate)

  @nn.compact
  def __call__(self, x, *, train):
    # ...
    x = self.dropout(x, deterministic=not train)
    # ...

但是,如上所述,deterministic 将是一个属性,因此这行不通。在这里,在 __call__ 期间传递 deterministic 是有意义的,因为它依赖于 train 参数。

解决方案#

我们可以通过允许某些属性作为数据类属性或方法参数传递(但不能同时传递!)来支持上述两种用例。这可以通过以下方式实现

class MyDropout(nn.Module):
  drop_rate: float
  deterministic: Optional[bool] = None

  @nn.compact
  def __call__(self, x, deterministic=None):
    deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
    # ...

在这个示例中,nn.merge_param 将确保 self.deterministicdeterministic 设置,但不能同时设置。如果两个值都为 None 或两个值都不为 None,则会引发错误。这避免了令人困惑的行为,即代码的两个不同部分设置了相同的参数,其中一个被另一个覆盖。它还避免了默认值,该默认值可能会导致训练过程的训练步骤或评估步骤默认情况下被破坏。

功能核心#

功能核心定义的是函数,而不是类。因此,超参数和调用时参数之间没有明确的区别。预先确定超参数的唯一方法是使用 partial。另一方面,没有方法参数也可能是属性的模糊情况。