setup vs compact#

在 Flax 的模块系统(称为 Linen)中,子模块和变量(参数或其他)可以通过两种方式定义

  1. **显式地**(使用 setup

    setup 方法中将子模块或变量分配给 self.<attr>。然后在类中定义的任何“前向传递”方法中使用分配给 self.<attr> 的子模块和变量。这类似于在 PyTorch 中定义模块的方式。

  2. **内联**(使用 nn.compact

    使用 nn.compact 注释的单个“前向传递”方法中直接编写网络的逻辑。这允许您在一个方法中定义整个模块,并将子模块和变量“共置”在它们被使用的位置旁边。

这两种方法都是完全有效的,行为相同,并且与 Flax 中的所有内容互操作.

以下是用两种方式定义的模块的简短示例,它们的功能完全相同。

class MLP(nn.Module):
  def setup(self):
    # Submodule names are derived by the attributes you assign to. In this
    # case, "dense1" and "dense2". This follows the logic in PyTorch.
    self.dense1 = nn.Dense(32)
    self.dense2 = nn.Dense(32)

  def __call__(self, x):
    x = self.dense1(x)
    x = nn.relu(x)
    x = self.dense2(x)
    return x
class MLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(32, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(32, name="dense2")(x)
    return x

那么,如何决定使用哪种风格呢?这可能是一个口味问题,但这里有一些优缺点

偏好使用 nn.compact 的原因:#

  1. 允许在子模块、参数和其他变量被使用的位置旁边定义它们:减少上下滚动以查看所有内容是如何定义的。

  2. 当有条件定义子模块、参数或变量的条件语句或循环时,可以减少代码重复。

  3. 代码通常看起来更像数学符号:y = self.param('W', ...) @ x + self.param('b', ...) 类似于 \(y=Wx+b\))

  4. 如果您使用的是形状推断,即使用形状/值依赖于输入形状(在初始化时未知)的参数,则无法使用 setup

偏好使用 setup 的原因:#

  1. 更接近 PyTorch 的约定,因此在从 PyTorch 移植模型时更容易

  2. 有些人发现将子模块和变量的定义与其使用位置明确分开更自然

  3. 允许定义多个“前向传递”方法(参见 MultipleMethodsCompactError