重构 FLIP 中的 RNNCellBase#

作者:Cristian Garcia、Marcus Chiam、Jasmijn Bastings

  • 开始日期:2023 年 5 月 1 日

  • FLIP 问题: [待定]

  • FLIP PR:#3053

  • 状态:已实现

摘要#

本提案旨在通过重构 initialize_carry 方法和其他相关组件来提高 RNNCellBase 类的可用性。

动机#

目前,initialize_carry 用于初始化 carry 并传递关键元数据,例如特征数量。API 可能会不直观,因为它要求用户手动计算通常可以由模块本身推断的项目,例如批量维度和特征维度的形状。

示例:ConvLSTM#

ConvLSTM 等情况下,当前 API 可能会不直观,因为 size 参数包含输入图像形状和输出特征维度。

x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)

#                                        image shape: vvvvvvv
carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16))
#                                   batch size: ^^             ^^ :output features

lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
(carry, y), initial_params = lstm.init_with_output(key2, carry, x)

本 FLIP 将提出对 initialize_carry 的一些更改,以便可以将上述示例简化为

x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)

lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
carry = lstm.initialize_carry(key1, input_shape=x.shape)

(carry, y), initial_params = lstm.init_with_output(key2, carry, x)

实现#

该提案建议进行以下更改

initialize_carry#

initialize_carry 应作为具有以下签名的实例方法重构

def initialize_carry(self, key, sample_input):

sample_input 应是一个与单元格处理的形状相同的数组,不包括时间轴。

重构 RNNCellBase 子类#

RNNCellBase 应重构为包含初始化单元格并执行其正向传递所需的元数据。对于 LSTMCellGRUCell,这意味着添加一个 features 属性,该属性应由用户在构造时提供。此更改与大多数其他 Module 的结构一致,使其对用户更熟悉。

x = jnp.ones((2, 100, 10)) # (batch, time, features)

cell = nn.LSTMCell(features=32)
carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input

(carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x)

num_feature_dims#

为了简化在 RNN 等抽象中处理 RNNCellBase 实例,每个单元格都应实现 num_feature_dims 属性。对于大多数单元格(例如 LSTMCellGRUCell),这始终为 1。对于 ConvLSTM 等单元格,这取决于它们的 kernel_size

讨论#

替代方法#

  • 为了消除对 num_feature_dims 的需求,RNN 仅能支持单个批量维度,即形式为 (batch, time, *features) 的输入。目前,它支持多个批量维度和多个特征维度。

  • 另一种方法可能是彻底重新设计 Flax 处理循环状态的方式。例如,memory 集合可以作为变量的一部分处理。但是,这会带来一些挑战,例如在训练期间处理无状态单元格、将状态从一层传递到另一层,以及在 scan 中执行初始化。

重构成本#

初始 TGP 结果显示 761 个测试失败,110 个测试失败。但是,在修复一个测试后,TGP 结果显示 231 个测试失败,13 个测试失败,因此失败的测试之间似乎有很多重叠。

为了最大程度地降低重构成本,将保留当前实现以供 Google 内部用户使用,名称为已弃用。这将允许用户根据自己的节奏迁移到新 API。对于开源用户,我们应该将 Flax 版本升级到 0.7.0,以便现有用户可以继续依赖 0.6.x 版本。