setup
vs compact
#
在 Flax 的模块系统(称为 Linen)中,子模块和变量(参数或其他)可以通过两种方式定义
**显式地**(使用
setup
)在
setup
方法中将子模块或变量分配给self.<attr>
。然后在类中定义的任何“前向传递”方法中使用分配给self.<attr>
的子模块和变量。这类似于在 PyTorch 中定义模块的方式。**内联**(使用
nn.compact
)使用
nn.compact
注释的单个“前向传递”方法中直接编写网络的逻辑。这允许您在一个方法中定义整个模块,并将子模块和变量“共置”在它们被使用的位置旁边。
这两种方法都是完全有效的,行为相同,并且与 Flax 中的所有内容互操作.
以下是用两种方式定义的模块的简短示例,它们的功能完全相同。
class MLP(nn.Module):
def setup(self):
# Submodule names are derived by the attributes you assign to. In this
# case, "dense1" and "dense2". This follows the logic in PyTorch.
self.dense1 = nn.Dense(32)
self.dense2 = nn.Dense(32)
def __call__(self, x):
x = self.dense1(x)
x = nn.relu(x)
x = self.dense2(x)
return x
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(32, name="dense1")(x)
x = nn.relu(x)
x = nn.Dense(32, name="dense2")(x)
return x
那么,如何决定使用哪种风格呢?这可能是一个口味问题,但这里有一些优缺点
偏好使用 nn.compact
的原因:#
允许在子模块、参数和其他变量被使用的位置旁边定义它们:减少上下滚动以查看所有内容是如何定义的。
当有条件定义子模块、参数或变量的条件语句或循环时,可以减少代码重复。
代码通常看起来更像数学符号:
y = self.param('W', ...) @ x + self.param('b', ...)
类似于 \(y=Wx+b\))如果您使用的是形状推断,即使用形状/值依赖于输入形状(在初始化时未知)的参数,则无法使用
setup
。
偏好使用 setup
的原因:#
更接近 PyTorch 的约定,因此在从 PyTorch 移植模型时更容易
有些人发现将子模块和变量的定义与其使用位置明确分开更自然
允许定义多个“前向传递”方法(参见
MultipleMethodsCompactError
)