从 Haiku 迁移到 Flax#
本指南将逐步介绍将 Haiku 模型迁移到 Flax 的过程,并重点介绍这两个库之间的差异。
基本示例#
要创建自定义模块,您需要在 Haiku 和 Flax 中都从 Module
基类进行子类化。但是,Haiku 类使用常规的 __init__
方法,而 Flax 类是 dataclasses
,这意味着您定义了一些用于自动生成构造函数的类属性。此外,所有 Flax 模块都接受一个 name
参数,无需定义它,而在 Haiku 中,name
必须在构造函数签名中显式定义并传递给超类构造函数。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
__call__
方法在这两个库中看起来非常相似,但是,在 Flax 中,您必须使用 @nn.compact
装饰器才能在内联定义子模块。在 Haiku 中,这是默认行为。
现在,Haiku 和 Flax 在构建模型方面存在很大差异。在 Haiku 中,您使用 hk.transform
对调用模块的函数进行转换,transform
将返回一个具有 init
和 apply
方法的对象。在 Flax 中,您只需实例化您的模块即可。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
...
model = Model(256, 10)
要获取这两个库中的模型参数,请使用 init
方法,并使用 random.key
加上一些输入来运行模型。这里的主要区别是 Flax 返回一个从集合名称到嵌套数组字典的映射,params
只是这些可能的集合之一。在 Haiku 中,您可以直接获得 params
结构。
sample_x = jax.numpy.ones((1, 784))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables["params"]
需要注意的一件非常重要的事情是,在 Flax 中,参数结构是分层的,每个嵌套模块有一层,最后一层用于参数名称。在 Haiku 中,参数结构是一个 Python 字典,具有两级层次结构:完全限定的模块名称映射到参数名称。模块名称由 /
分隔的字符串路径组成,该路径包含所有嵌套模块。
...
{
'model/block/linear': {
'b': (256,),
'w': (784, 256),
},
'model/linear': {
'b': (10,),
'w': (256, 10),
}
}
...
FrozenDict({
Block_0: {
Dense_0: {
bias: (256,),
kernel: (784, 256),
},
},
Dense_0: {
bias: (10,),
kernel: (256, 10),
},
})
在两个框架中的训练过程中,您将参数结构传递给 apply
方法以运行正向传递。由于我们使用的是丢弃,因此在这两种情况下,我们都必须向 apply
提供一个 key
以便生成随机丢弃掩码。
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params,
key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
最显著的差异是,在 Flax 中,您必须将参数放在带有 params
键的字典中,并将键放在带有 dropout
键的字典中。这是因为在 Flax 中,您可以拥有多种类型的模型状态和随机状态。在 Haiku 中,您只需直接传递参数和键。
处理状态#
现在让我们看看这两个库是如何处理可变状态的。我们将使用与之前相同的模型,但现在我们将用批量归一化替换丢弃。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.BatchNorm(
momentum=0.99
)(x, use_running_average=not training)
x = jax.nn.relu(x)
return x
在这种情况下,代码非常相似,因为两个库都提供了一个批量归一化层。最显著的差异是 Haiku 使用 is_training
来控制是否更新运行统计信息,而 Flax 使用 use_running_average
来实现相同目的。
要在 Haiku 中实例化有状态模型,您需要使用 hk.transform_with_state
,它会更改 init
和 apply
的签名以接受和返回状态。与之前一样,在 Flax 中,您直接构建模块。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
...
model = Model(256, 10)
要初始化参数和状态,您只需像以前一样调用 init
方法即可。但是,在 Haiku 中,您现在会获得 state
作为第二个返回值,而在 Flax 中,您会在 variables
字典中获得一个新的 batch_stats
集合。请注意,由于 hk.BatchNorm
仅在 is_training=True
时初始化批量统计信息,因此我们在初始化具有 hk.BatchNorm
层的 Haiku 模型的参数时必须设置 training=True
。在 Flax 中,我们可以像往常一样设置 training=False
。
sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
random.key(0),
sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
通常,在 Flax 中,您可能会在 variables
字典中找到其他状态集合,例如 cache
(用于自回归 Transformer 模型)、intermediates
(用于使用 Module.sow
添加的中间值)或其他由自定义层定义的集合名称。Haiku 仅区分 params
(在运行 apply
时不会改变的变量)和 state
(在运行 apply
时可能会改变的变量),并使用 hk.transform
或 hk.transform_with_state
。
现在,训练在这两个框架中看起来非常相似,因为您使用相同的 apply
方法来运行正向传递。在 Haiku 中,现在将 state
作为第二个参数传递给 apply
,并将新状态作为第二个返回值获得。在 Flax 中,您改为将 batch_stats
作为新键添加到输入字典中,并将 updates
变量字典作为第二个返回值获得。
def train_step(params, state, inputs, labels):
def loss_fn(params):
logits, new_state = model.apply(
params, state,
None, # <== rng
inputs, training=True # <== inputs
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, new_state
grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, new_state
def train_step(params, batch_stats, inputs, labels):
def loss_fn(params):
logits, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
inputs, training=True, # <== inputs
mutable='batch_stats',
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, updates["batch_stats"]
grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, batch_stats
一个主要区别是,在 Flax 中,状态集合可以是可变的,也可以是不可变的。在 init
期间,默认情况下所有集合都是可变的,但是,在 apply
期间,您必须显式指定哪些集合是可变的。在本例中,我们指定 batch_stats
是可变的。这里传递的是单个字符串,但如果存在更多可变集合,则也可以传递列表。如果未执行此操作,则在尝试更改 batch_stats
时,将在运行时引发错误。此外,当 mutable
不是 False
时,updates
字典将作为 apply
的第二个返回值返回,否则仅返回模型输出。Haiku 通过使用 params
(不可变)和 state
(可变)以及使用 hk.transform
或 hk.transform_with_state
来区分可变/不可变。
使用多个方法#
在本节中,我们将了解如何在 Haiku 和 Flax 中使用多个方法。例如,我们将实现一个具有三个方法的自编码器模型:encode
、decode
和 __call__
。
在 Haiku 中,我们只需直接在 __init__
中定义 encode
和 decode
所需的子模块,在本例中,每个模块都将使用 Linear
层。在 Flax 中,我们将在 setup
中提前定义 encoder
和 decoder
模块,并在 encode
和 decode
中分别使用它们。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
请注意,在 Flax 中,setup
不会在 __init__
之后运行,而是在调用 init
或 apply
时运行。
现在,我们希望能够从我们的 AutoEncoder
模型中调用任何方法。在 Haiku 中,我们可以通过 hk.multi_transform
为模块定义多个 apply
方法。传递给 multi_transform
的函数定义了如何初始化模块以及要生成哪些不同的应用方法。
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
...
model = AutoEncoder(256, 784)
为了初始化模型的参数,可以使用 init
触发 __call__
方法,该方法同时使用 encode
和 decode
方法。这将创建模型所需的所有参数。
params = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
params = variables["params"]
这将生成以下参数结构。
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
FrozenDict({
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
})
最后,让我们探索如何使用 apply
函数调用 encode
方法。
encode, decode = model.apply
z = encode(
params,
None, # <== rng
x=jax.numpy.ones((1, 784)),
)
...
z = model.apply(
{"params": params},
x=jax.numpy.ones((1, 784)),
method="encode",
)
由于 Haiku apply
函数是通过 hk.multi_transform
生成的,它是一个包含两个函数的元组,我们可以将其解包为一个 encode
函数和一个 decode
函数,它们对应于 AutoEncoder
模块上的方法。在 Flax 中,我们通过将方法名称作为字符串传递来调用 encode
方法。这里另一个值得注意的区别是,在 Haiku 中,即使模块在 apply
期间没有使用任何随机操作,也需要显式地传递 rng
。在 Flax 中,这不是必需的(请查看 Flax 中的随机性和 PRNG)。这里的 Haiku rng
设置为 None
,但你也可以在 apply
函数上使用 hk.without_apply_rng
来删除 rng
参数。
提升的转换#
Flax 和 Haiku 都提供了一组转换,我们将它们称为提升的转换,这些转换以一种可以与模块一起使用的方式包装 JAX 转换,并且有时会提供额外的功能。在本节中,我们将了解如何在 Flax 和 Haiku 中使用 scan
的提升版本来实现一个简单的 RNN 层。
首先,我们将定义一个 RNNCell
模块,该模块将包含 RNN 单步的逻辑。我们还将定义一个 initial_state
方法,该方法将用于初始化 RNN 的状态(也称为 carry
)。与 jax.lax.scan
一样,RNNCell.__call__
方法将是一个函数,它接收传递和输入,并返回新的传递和输出。在这种情况下,传递和输出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个 RNN
模块,该模块将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell
,然后使用它来构建 carry
,最后使用 hk.scan
在输入序列上运行 RNNCell
。在 Flax 中,它的做法略有不同,我们将使用 nn.scan
来定义一个新的临时类型,该类型包装 RNNCell
。在此过程中,我们还将指定指示 nn.scan
广播 params
集合(所有步骤共享相同的参数)并且不拆分 params
rng 流(以便所有步骤使用相同的参数进行初始化),最后,我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出也沿着第二个轴堆叠起来。然后,我们将立即使用此临时类型来创建提升的 RNNCell
的实例,并使用它来创建 carry
并运行 __call__
方法,该方法将在序列上进行 scan
。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
in_axes=1, out_axes=1)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
总的来说,Flax 和 Haiku 之间提升的转换的主要区别在于,在 Haiku 中,提升的转换不会对状态进行操作,也就是说,Haiku 将以一种在转换内外保持相同形状的方式处理 params
和 state
。在 Flax 中,提升的转换可以对变量集合和 rng 流进行操作,用户必须根据转换的语义定义每个转换如何处理不同的集合。
最后,让我们快速查看如何在 Haiku 和 Flax 中使用 RNN
模块。
def forward(x):
return RNN(64)(x)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
y = model.apply(
params,
x=jax.numpy.ones((3, 12, 32)),
)
...
model = RNN(64)
variables = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
{'params': params},
x=jax.numpy.ones((3, 12, 32)),
)
与前面部分中的示例相比,唯一值得注意的变化是这次我们在 Haiku 中使用了 hk.without_apply_rng
,因此我们不必将 rng
参数作为 None
传递给 apply
方法。
在层上进行扫描#
scan
的一个非常重要的应用是,迭代地在输入上应用一系列层,将每个层的输出作为下一个层的输入传递。这对于减少大型模型的编译时间非常有用。例如,我们将创建一个简单的 Block
模块,然后将其用在 MLP
模块中,该模块将应用 num_layers
次的 Block
模块。
在 Haiku 中,我们像往常一样定义 Block
模块,然后在 MLP
中,我们将在 stack_block
函数上使用 hk.experimental.layer_stack
来创建一个 Block
模块堆栈。在 Flax 中,Block
的定义略有不同,__call__
将接收和返回一个第二个虚拟输入/输出,它们在两种情况下都将为 None
。在 MLP
中,我们将像在前面的示例中一样使用 nn.scan
,但通过设置 split_rngs={'params': True}
和 variable_axes={'params': 0}
,我们告诉 nn.scan
为每个步骤创建不同的参数并沿着第一个轴切片 params
集合,从而有效地实现一个 Block
模块堆栈,就像在 Haiku 中一样。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
请注意,在 Flax 中,我们如何将 None
作为第二个参数传递给 ScanBlock
并忽略其第二个输出。这些表示每个步骤的输入/输出,但它们是 None
,因为在这种情况下,我们没有任何输入/输出。
初始化每个模型与前面的示例相同。在这种情况下,我们将指定我们希望使用 5
个层,每个层具有 64
个特征。
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jax.numpy.ones((1, 64))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
...
model = MLP(64, num_layers=5)
sample_x = jax.numpy.ones((1, 64))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables['params']
当在层上使用 scan 时,你应该注意的是,所有层都融合成一个单独的层,其参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...)
开头,因为我们使用的是 5
个层。
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
FrozenDict({
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
})
顶级 Haiku 函数与顶级 Flax 模块#
在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state}
来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶级“模块”编写为函数非常常见。
Flax 团队建议使用更以模块为中心的 approach,该 approach 使用 __call__ 来定义前向函数。相应的访问器将是 nn.module.param 和 nn.module.variable(有关集合的解释,请转到 处理状态)。
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
@nn.compact
def __call__(self, x):
counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
output = x + multiplier * counter.value
if not self.is_initializing(): # otherwise model.init() also increases it
counter.value += 1
return output
model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']