处理 Flax 模块参数#
介绍#
在 Flax Linen 中,我们可以将 Module
参数定义为数据类属性或作为方法参数(通常是 __call__
)。通常,区别很明显
完全固定的属性,例如内核初始化器的选择或输出特征的数量,是超参数,应定义为数据类属性。通常,两个具有不同超参数的模块实例不能以有意义的方式共享。
动态属性,例如输入数据和顶级“模式开关”如
train=True/False
,应作为参数传递给__call__
或其他方法。
然而,有些情况并不那么清晰。例如,Dropout
模块。我们有一些明确的超参数
丢弃率
为其生成丢弃掩码的轴
以及一些明确的调用时参数
应使用丢弃进行掩码的输入
用于采样随机掩码的(可选)rng
但是,有一个属性很模糊——Dropout
模块中的 deterministic
属性。
如果 deterministic
为 True
,则不会采样丢弃掩码。这通常在模型评估期间使用。但是,如果我们将 eval=True
或 train=False
传递给顶级模块。则需要将 deterministic
参数应用到所有地方,并且需要将布尔值参数传递给所有可能使用 Dropout
的层。相反,如果 deterministic
是一个数据类属性,我们可能会执行以下操作
from functools import partial
from flax import linen as nn
class ResidualModel(nn.Module):
drop_rate: float
@nn.compact
def __call__(self, x, *, train):
dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
for i in range(10):
x += ResidualBlock(dropout=dropout, ...)(x)
将 determinstic
传递给构造函数是有意义的,因为这样我们就可以将丢弃模板传递给子模块。现在,子模块不再需要处理训练与评估模式,并且可以直接使用 dropout
参数。请注意,由于丢弃层只能在子模块中构建,因此我们只能将 deterministic
部分应用于构造函数,而不能应用于 __call__
。
但是,如果 deterministic
是一个数据类属性,我们在使用设置模式时会遇到麻烦。我们 **希望** 将模块代码写成这样
class SomeModule(nn.Module):
drop_rate: float
def setup(self):
self.dropout = nn.Dropout(rate=self.drop_rate)
@nn.compact
def __call__(self, x, *, train):
# ...
x = self.dropout(x, deterministic=not train)
# ...
但是,如上所述,deterministic
将是一个属性,因此这行不通。在这里,在 __call__
期间传递 deterministic
是有意义的,因为它依赖于 train
参数。
解决方案#
我们可以通过允许某些属性作为数据类属性或方法参数传递(但不能同时传递!)来支持上述两种用例。这可以通过以下方式实现
class MyDropout(nn.Module):
drop_rate: float
deterministic: Optional[bool] = None
@nn.compact
def __call__(self, x, deterministic=None):
deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
# ...
在这个示例中,nn.merge_param
将确保 self.deterministic
或 deterministic
设置,但不能同时设置。如果两个值都为 None
或两个值都不为 None
,则会引发错误。这避免了令人困惑的行为,即代码的两个不同部分设置了相同的参数,其中一个被另一个覆盖。它还避免了默认值,该默认值可能会导致训练过程的训练步骤或评估步骤默认情况下被破坏。
功能核心#
功能核心定义的是函数,而不是类。因此,超参数和调用时参数之间没有明确的区别。预先确定超参数的唯一方法是使用 partial
。另一方面,没有方法参数也可能是属性的模糊情况。