重构 FLIP 中的 RNNCellBase#
作者:Cristian Garcia、Marcus Chiam、Jasmijn Bastings
开始日期:2023 年 5 月 1 日
FLIP 问题: [待定]
FLIP PR:#3053
状态:已实现
摘要#
本提案旨在通过重构 initialize_carry
方法和其他相关组件来提高 RNNCellBase
类的可用性。
动机#
目前,initialize_carry
用于初始化 carry 并传递关键元数据,例如特征数量。API 可能会不直观,因为它要求用户手动计算通常可以由模块本身推断的项目,例如批量维度和特征维度的形状。
示例:ConvLSTM#
在 ConvLSTM
等情况下,当前 API 可能会不直观,因为 size
参数包含输入图像形状和输出特征维度。
x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)
# image shape: vvvvvvv
carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16))
# batch size: ^^ ^^ :output features
lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
(carry, y), initial_params = lstm.init_with_output(key2, carry, x)
本 FLIP 将提出对 initialize_carry
的一些更改,以便可以将上述示例简化为
x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)
lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
carry = lstm.initialize_carry(key1, input_shape=x.shape)
(carry, y), initial_params = lstm.init_with_output(key2, carry, x)
实现#
该提案建议进行以下更改
initialize_carry#
initialize_carry
应作为具有以下签名的实例方法重构
def initialize_carry(self, key, sample_input):
sample_input
应是一个与单元格处理的形状相同的数组,不包括时间轴。
重构 RNNCellBase 子类#
RNNCellBase
应重构为包含初始化单元格并执行其正向传递所需的元数据。对于 LSTMCell
和 GRUCell
,这意味着添加一个 features
属性,该属性应由用户在构造时提供。此更改与大多数其他 Module
的结构一致,使其对用户更熟悉。
x = jnp.ones((2, 100, 10)) # (batch, time, features)
cell = nn.LSTMCell(features=32)
carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input
(carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x)
num_feature_dims#
为了简化在 RNN
等抽象中处理 RNNCellBase
实例,每个单元格都应实现 num_feature_dims
属性。对于大多数单元格(例如 LSTMCell
和 GRUCell
),这始终为 1。对于 ConvLSTM
等单元格,这取决于它们的 kernel_size
。
讨论#
替代方法#
为了消除对
num_feature_dims
的需求,RNN
仅能支持单个批量维度,即形式为(batch, time, *features)
的输入。目前,它支持多个批量维度和多个特征维度。另一种方法可能是彻底重新设计 Flax 处理循环状态的方式。例如,
memory
集合可以作为变量的一部分处理。但是,这会带来一些挑战,例如在训练期间处理无状态单元格、将状态从一层传递到另一层,以及在scan
中执行初始化。
重构成本#
初始 TGP 结果显示 761 个测试失败,110 个测试失败。但是,在修复一个测试后,TGP 结果显示 231 个测试失败,13 个测试失败,因此失败的测试之间似乎有很多重叠。
为了最大程度地降低重构成本,将保留当前实现以供 Google 内部用户使用,名称为已弃用。这将允许用户根据自己的节奏迁移到新 API。对于开源用户,我们应该将 Flax 版本升级到 0.7.0
,以便现有用户可以继续依赖 0.6.x
版本。