从 Haiku 迁移到 Flax#
本指南演示了 Haiku 和 Flax NNX 模型之间的差异,并提供了并排的示例代码,以帮助您从 Haiku 迁移到 Flax NNX API。
如果您是 Flax NNX 的新手,请务必熟悉 Flax NNX 基础知识,其中涵盖了 nnx.Module
系统、Flax 转换和 功能性 API 以及示例。
让我们从一些导入开始。
基本模块定义#
Haiku 和 Flax 都使用 Module
类作为表达神经网络库层的默认单元。例如,要创建一个具有 dropout 和 ReLU 激活函数的单层网络,您需要
首先,创建一个
Block
(通过子类化Module
),它由一个带有 dropout 和 ReLU 激活函数的线性层组成。然后,在创建
Model
(也通过子类化Module
)时,将Block
用作子Module
,该Model
由Block
和一个线性层组成。
Haiku 和 Flax Module
对象之间有两个根本区别
无状态 vs. 有状态:
haiku.Module
实例是无状态的。这意味着,变量是从纯功能性Module.init()
调用返回并单独管理的。然而,
flax.nnx.Module
将其变量作为此 Python 对象的属性拥有。
延迟 vs. 立即:
haiku.Module
仅在用户调用模型时实际看到输入时才分配空间来创建变量(延迟)。flax.nnx.Module
实例在实例化时(在看到样本输入之前)创建变量(立即)。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
变量创建#
本节介绍实例化模型和初始化其参数。
要为 Haiku 模型生成模型参数,您需要将其放入前向函数中,并使用
haiku.transform
使其成为纯功能性。这会导致一个 JAX 数组(jax.Array
数据类型)的嵌套字典,需要单独携带和维护。在 Flax NNX 中,当您实例化模型时,会自动初始化模型参数,并且变量(
nnx.Variable
对象)作为属性存储在nnx.Module
(或其子模块)中。您仍然需要为它提供一个 伪随机数生成器 (PRNG) 密钥,但该密钥将包装在nnx.Rngs
类中并存储在其中,并在需要时生成更多 PRNG 密钥。
如果您想以无状态的、类似字典的方式访问 Flax 模型参数以进行检查点保存或模型手术,请查看 Flax NNX 分割/合并 API(nnx.split
/ nnx.merge
)。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)
assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
训练步骤和编译#
本节介绍如何编写训练步骤并使用 JAX 即时编译对其进行编译。
编译训练步骤时
Haiku 使用
@jax.jit
- 一个 JAX 转换 - 来编译一个纯功能性训练步骤。Flax NNX 使用
@nnx.jit
- 一个 Flax NNX 转换(几个转换 API 之一,其行为类似于 JAX 转换,但也 可以很好地与 Flax 对象一起工作)。虽然jax.jit
只接受具有纯无状态参数的函数,但flax.nnx.jit
允许参数为有状态的模块。这大大减少了训练步骤所需的行数。
在求梯度时
类似地,Haiku 使用
jax.grad
(用于 自动微分的 JAX 转换)来返回梯度的原始字典。同时,Flax NNX 使用
flax.nnx.grad
(Flax NNX 转换)来返回 Flax NNX 模块的梯度,作为flax.nnx.State
字典。如果您想将常规jax.grad
与 Flax NNX 一起使用,则需要使用 分割/合并 API。
对于优化器
如果您已经在使用 Optax 优化器(如
optax.adamw
)(而不是此处显示的原始jax.tree.map
计算)与 Haiku,请查看 Flax 基础知识指南中的flax.nnx.Optimizer
示例,了解一种更简洁的训练和更新模型的方法。
每次训练步骤期间的模型更新
Haiku 训练步骤需要返回一个 JAX 树的参数作为下一步的输入。
Flax NNX 训练步骤不需要返回任何内容,因为
model
已经在nnx.jit
内就地更新。此外,
nnx.Module
对象是有状态的,并且Module
会自动跟踪其中的几件事,例如 PRNG 密钥和flax.nnx.BatchNorm
统计信息。这就是为什么您无需在每一步都显式传递 PRNG 密钥的原因。另请注意,您可以使用flax.nnx.reseed
来重置其底层的 PRNG 状态。
dropout 行为
在 Haiku 中,您需要显式定义并传入
training
参数来切换haiku.dropout
,并确保只有在training=True
时才会发生随机 dropout。在 Flax NNX 中,你可以调用
model.train()
(flax.nnx.Module.train()
) 来自动将flax.nnx.Dropout
切换到训练模式。相反,你可以调用model.eval()
(flax.nnx.Module.eval()
) 来关闭训练模式。你可以在其 API 参考 中了解更多关于flax.nnx.Module.train
的信息。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params, key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # set deterministic=False
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(
inputs, # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
处理非参数状态#
Haiku 区分了可训练参数和模型跟踪的所有其他数据(“状态”)。例如,批量归一化中使用的批次统计信息被视为状态。具有状态的模型需要使用 hk.transform_with_state
进行转换,以便其 .init()
返回参数和状态。
在 Flax 中,没有如此严格的区分 - 它们都是 nnx.Variable
的子类,并被模块视为其属性。参数是名为 nnx.Param
的子类的实例,而批次统计信息可以是另一个名为 nnx.BatchStat
的子类。你可以使用 nnx.split
快速提取特定变量类型的所有数据。
让我们通过采用上面的 Block
定义,但将 dropout 替换为 BatchNorm
来看看这个例子。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
Flax 会考虑可训练参数和其他数据之间的差异。nnx.grad
只会对 nnx.Param
变量求梯度,从而自动跳过 batchnorm
数组。因此,对于使用此模型的 Flax NNX,训练步骤将看起来相同。
使用多个方法#
在本节中,你将学习如何在 Haiku 和 Flax 中使用多个方法。例如,你将实现一个具有三个方法的自动编码器模型:encode
、decode
和 __call__
。
在 Haiku 中,你需要使用 hk.multi_transform
显式定义如何初始化模型以及它可以调用哪些方法(这里的 encode
和 decode
)。请注意,你仍然需要定义一个 __call__
,它会激活两个层,以便延迟初始化所有模型参数。
在 Flax 中,它更简单,因为你在 __init__
中初始化了参数,并且可以直接使用 nnx.Module
方法 encode
和 decode
。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...
参数结构如下
...
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
},
'encoder': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
要调用这些自定义方法
在 Haiku 中,你需要解耦 .apply 函数以在调用它之前提取你的方法。
在 Flax 中,你可以直接调用该方法。
encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))
转换#
Haiku 和 Flax 转换 都提供了自己的转换集,这些转换以一种可以与 Module
对象一起使用的方式包装了 JAX 转换。
有关 Flax 转换的更多信息,请查看 转换指南。
让我们从一个例子开始
首先,定义一个
RNNCell
Module
,其中将包含 RNN 单步的逻辑。定义一个
initial_state
方法,该方法将用于初始化 RNN 的状态(又名carry
)。与jax.lax.scan
(API 文档) 一样,RNNCell.__call__
方法将是一个接收 carry 和输入并返回新的 carry 和输出的函数。在这种情况下,carry 和输出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个 RNN
模块,其中将包含整个 RNN 的逻辑。在这两种情况下,我们都使用库的 scan
调用来运行输入序列上的 RNNCell
。
唯一的区别是 Flax nnx.scan
允许你指定在参数 in_axes
和 out_axes
中重复的轴,这将转发到底层的 `jax.lax.scan<https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html>`__,而在 Haiku 中,你需要显式地转置输入和输出。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(
cell, carry,
jnp.swapaxes(x, 1, 0)
)
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
扫描层#
大多数 Haiku 转换应与 Flax 类似,因为它们都包装了它们的 JAX 对等物,但扫描层用例是一个例外。
扫描层是一种技术,你可以在其中通过一系列 N 个重复层运行输入,将每个层的输出作为下一层的输入传递。这种模式可以显著减少大型模型的编译时间。在下面的示例中,你将在顶层的 MLP
Module
中重复 Block
Module
5 次。
在 Haiku 中,我们像往常一样定义 Block
模块,然后在 MLP
内部,我们将使用 hk.experimental.layer_stack
在 stack_block
函数之上来创建 Block
模块的堆栈。相同的代码将在初始化时创建 5 层参数,并在调用时通过它们运行输入。
在 Flax 中,模型初始化和调用代码是完全解耦的,因此我们使用 nnx.vmap
转换来初始化底层 Block
参数,并使用 nnx.scan
转换来运行模型输入。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
在上面的 Flax 示例中还有一些其他细节需要解释
`@nnx.split_rngs` 装饰器: Flax 转换与其 JAX 对等物一样,完全不了解 PRNG 状态,并且依赖于 PRNG 密钥的输入。
nnx.split_rngs
装饰器允许你在将nnx.Rngs
传递给装饰函数之前拆分它们,并在之后“降低”它们,以便它们可以在外部使用。在这里,你拆分了 PRNG 密钥,因为如果其每个内部操作都需要自己的密钥,则
jax.vmap
和jax.lax.scan
需要 PRNG 密钥列表。因此,对于MLP
内的 5 层,你在转到 JAX 转换之前拆分并从其参数提供 5 个不同的 PRNG 密钥。请注意,实际上
create_block()
知道它需要创建 5 层正是因为它看到了 5 个 PRNG 密钥,因为in_axes=(0,)
指示vmap
将查看第一个参数的第一个维度以了解它将映射的大小。对于
forward()
也是如此,它会查看第一个参数(又名model
)中的变量,以找出它需要扫描多少次。nnx.split_rngs
这里实际上拆分了model
内部的 PRNG 状态。(如果Block
Module
没有 dropout,则你不需要nnx.split_rngs
行,因为它无论如何都不会消耗任何 PRNG 密钥。)
为什么 Flax 中的 Block 模块不需要接收和返回额外的虚拟值:
jax.lax.scan
(API 文档) 要求其函数返回两个输入 - 进位和堆叠输出。在这种情况下,我们没有使用后者。Flax 对此进行了简化,因此如果你将out_axes
设置为nnx.Carry
而不是默认的(nnx.Carry, 0)
,现在可以选择忽略第二个输出。这是 Flax NNX 转换与 JAX 转换 API 不同的少数情况之一。
上面的 Flax 示例中有更多的代码行,但它们更精确地表达了每次发生的情况。由于 Flax 转换越来越接近 JAX 转换 API,因此建议在使用其 Flax NNX 等效项 之前,先很好地理解底层的 JAX 转换。
现在检查两侧的变量 pytree
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
顶层 Haiku 函数 vs 顶层 Flax 模块#
在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state}
来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”编写为函数是非常常见的做法。
Flax 团队推荐一种更以模块为中心的方法,该方法使用 __call__
来定义前向函数。在 Flax 模块中,可以使用常规 Python 类语义,像正常方式一样设置和访问参数和变量。
...
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter(
'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
pass
class FooModule(nnx.Module):
def __init__(self, rngs):
self.counter = Counter(jnp.ones((), jnp.int32))
self.multiplier = nnx.Param(
nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
)
def __call__(self, x):
output = x + self.multiplier * self.counter.value
self.counter.value += 1
return output
model = FooModule(rngs=nnx.Rngs(0))
_, params, counter = nnx.split(model, nnx.Param, Counter)