从 Haiku 迁移到 Flax#

本指南将逐步介绍将 Haiku 模型迁移到 Flax 的过程,并重点介绍这两个库之间的差异。

基本示例#

要创建自定义模块,您需要在 Haiku 和 Flax 中都从 Module 基类进行子类化。但是,Haiku 类使用常规的 __init__ 方法,而 Flax 类是 dataclasses,这意味着您定义了一些用于自动生成构造函数的类属性。此外,所有 Flax 模块都接受一个 name 参数,无需定义它,而在 Haiku 中,name 必须在构造函数签名中显式定义并传递给超类构造函数。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x

__call__ 方法在这两个库中看起来非常相似,但是,在 Flax 中,您必须使用 @nn.compact 装饰器才能在内联定义子模块。在 Haiku 中,这是默认行为。

现在,Haiku 和 Flax 在构建模型方面存在很大差异。在 Haiku 中,您使用 hk.transform 对调用模块的函数进行转换,transform 将返回一个具有 initapply 方法的对象。在 Flax 中,您只需实例化您的模块即可。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
...


model = Model(256, 10)

要获取这两个库中的模型参数,请使用 init 方法,并使用 random.key 加上一些输入来运行模型。这里的主要区别是 Flax 返回一个从集合名称到嵌套数组字典的映射,params 只是这些可能的集合之一。在 Haiku 中,您可以直接获得 params 结构。

sample_x = jax.numpy.ones((1, 784))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables["params"]

需要注意的一件非常重要的事情是,在 Flax 中,参数结构是分层的,每个嵌套模块有一层,最后一层用于参数名称。在 Haiku 中,参数结构是一个 Python 字典,具有两级层次结构:完全限定的模块名称映射到参数名称。模块名称由 / 分隔的字符串路径组成,该路径包含所有嵌套模块。

...
{
  'model/block/linear': {
    'b': (256,),
    'w': (784, 256),
  },
  'model/linear': {
    'b': (10,),
    'w': (256, 10),
  }
}
...
FrozenDict({
  Block_0: {
    Dense_0: {
      bias: (256,),
      kernel: (784, 256),
    },
  },
  Dense_0: {
    bias: (10,),
    kernel: (256, 10),
  },
})

在两个框架中的训练过程中,您将参数结构传递给 apply 方法以运行正向传递。由于我们使用的是丢弃,因此在这两种情况下,我们都必须向 apply 提供一个 key 以便生成随机丢弃掩码。

def train_step(key, params, inputs, labels):
  def loss_fn(params):
      logits = model.apply(
        params,
        key,
        inputs, training=True # <== inputs
      )
      return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
def train_step(key, params, inputs, labels):
  def loss_fn(params):
      logits = model.apply(
        {'params': params},
        inputs, training=True, # <== inputs
        rngs={'dropout': key}
      )
      return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params

最显著的差异是,在 Flax 中,您必须将参数放在带有 params 键的字典中,并将键放在带有 dropout 键的字典中。这是因为在 Flax 中,您可以拥有多种类型的模型状态和随机状态。在 Haiku 中,您只需直接传递参数和键。

处理状态#

现在让我们看看这两个库是如何处理可变状态的。我们将使用与之前相同的模型,但现在我们将用批量归一化替换丢弃。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x
class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.BatchNorm(
      momentum=0.99
    )(x, use_running_average=not training)
    x = jax.nn.relu(x)
    return x

在这种情况下,代码非常相似,因为两个库都提供了一个批量归一化层。最显著的差异是 Haiku 使用 is_training 来控制是否更新运行统计信息,而 Flax 使用 use_running_average 来实现相同目的。

要在 Haiku 中实例化有状态模型,您需要使用 hk.transform_with_state,它会更改 initapply 的签名以接受和返回状态。与之前一样,在 Flax 中,您直接构建模块。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform_with_state(forward)
...


model = Model(256, 10)

要初始化参数和状态,您只需像以前一样调用 init 方法即可。但是,在 Haiku 中,您现在会获得 state 作为第二个返回值,而在 Flax 中,您会在 variables 字典中获得一个新的 batch_stats 集合。请注意,由于 hk.BatchNorm 仅在 is_training=True 时初始化批量统计信息,因此我们在初始化具有 hk.BatchNorm 层的 Haiku 模型的参数时必须设置 training=True。在 Flax 中,我们可以像往常一样设置 training=False

sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
  random.key(0),
  sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]

通常,在 Flax 中,您可能会在 variables 字典中找到其他状态集合,例如 cache(用于自回归 Transformer 模型)、intermediates(用于使用 Module.sow 添加的中间值)或其他由自定义层定义的集合名称。Haiku 仅区分 params(在运行 apply 时不会改变的变量)和 state(在运行 apply 时可能会改变的变量),并使用 hk.transformhk.transform_with_state

现在,训练在这两个框架中看起来非常相似,因为您使用相同的 apply 方法来运行正向传递。在 Haiku 中,现在将 state 作为第二个参数传递给 apply,并将新状态作为第二个返回值获得。在 Flax 中,您改为将 batch_stats 作为新键添加到输入字典中,并将 updates 变量字典作为第二个返回值获得。

def train_step(params, state, inputs, labels):
  def loss_fn(params):
    logits, new_state = model.apply(
      params, state,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, new_state

  grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, new_state
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params):
    logits, updates = model.apply(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable='batch_stats',
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates["batch_stats"]

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats

一个主要区别是,在 Flax 中,状态集合可以是可变的,也可以是不可变的。在 init 期间,默认情况下所有集合都是可变的,但是,在 apply 期间,您必须显式指定哪些集合是可变的。在本例中,我们指定 batch_stats 是可变的。这里传递的是单个字符串,但如果存在更多可变集合,则也可以传递列表。如果未执行此操作,则在尝试更改 batch_stats 时,将在运行时引发错误。此外,当 mutable 不是 False 时,updates 字典将作为 apply 的第二个返回值返回,否则仅返回模型输出。Haiku 通过使用 params(不可变)和 state(可变)以及使用 hk.transformhk.transform_with_state 来区分可变/不可变。

使用多个方法#

在本节中,我们将了解如何在 Haiku 和 Flax 中使用多个方法。例如,我们将实现一个具有三个方法的自编码器模型:encodedecode__call__

在 Haiku 中,我们只需直接在 __init__ 中定义 encodedecode 所需的子模块,在本例中,每个模块都将使用 Linear 层。在 Flax 中,我们将在 setup 中提前定义 encoderdecoder 模块,并在 encodedecode 中分别使用它们。

class AutoEncoder(hk.Module):


  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

请注意,在 Flax 中,setup 不会在 __init__ 之后运行,而是在调用 initapply 时运行。

现在,我们希望能够从我们的 AutoEncoder 模型中调用任何方法。在 Haiku 中,我们可以通过 hk.multi_transform 为模块定义多个 apply 方法。传递给 multi_transform 的函数定义了如何初始化模块以及要生成哪些不同的应用方法。

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
...




model = AutoEncoder(256, 784)

为了初始化模型的参数,可以使用 init 触发 __call__ 方法,该方法同时使用 encodedecode 方法。这将创建模型所需的所有参数。

params = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
params = variables["params"]

这将生成以下参数结构。

{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
FrozenDict({
    decoder: {
        bias: (784,),
        kernel: (256, 784),
    },
    encoder: {
        bias: (256,),
        kernel: (784, 256),
    },
})

最后,让我们探索如何使用 apply 函数调用 encode 方法。

encode, decode = model.apply
z = encode(
  params,
  None, # <== rng
  x=jax.numpy.ones((1, 784)),

)
...
z = model.apply(
  {"params": params},

  x=jax.numpy.ones((1, 784)),
  method="encode",
)

由于 Haiku apply 函数是通过 hk.multi_transform 生成的,它是一个包含两个函数的元组,我们可以将其解包为一个 encode 函数和一个 decode 函数,它们对应于 AutoEncoder 模块上的方法。在 Flax 中,我们通过将方法名称作为字符串传递来调用 encode 方法。这里另一个值得注意的区别是,在 Haiku 中,即使模块在 apply 期间没有使用任何随机操作,也需要显式地传递 rng。在 Flax 中,这不是必需的(请查看 Flax 中的随机性和 PRNG)。这里的 Haiku rng 设置为 None,但你也可以在 apply 函数上使用 hk.without_apply_rng 来删除 rng 参数。

提升的转换#

Flax 和 Haiku 都提供了一组转换,我们将它们称为提升的转换,这些转换以一种可以与模块一起使用的方式包装 JAX 转换,并且有时会提供额外的功能。在本节中,我们将了解如何在 Flax 和 Haiku 中使用 scan 的提升版本来实现一个简单的 RNN 层。

首先,我们将定义一个 RNNCell 模块,该模块将包含 RNN 单步的逻辑。我们还将定义一个 initial_state 方法,该方法将用于初始化 RNN 的状态(也称为 carry)。与 jax.lax.scan 一样,RNNCell.__call__ 方法将是一个函数,它接收传递和输入,并返回新的传递和输出。在这种情况下,传递和输出是相同的。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,我们将定义一个 RNN 模块,该模块将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell,然后使用它来构建 carry,最后使用 hk.scan 在输入序列上运行 RNNCell。在 Flax 中,它的做法略有不同,我们将使用 nn.scan 来定义一个新的临时类型,该类型包装 RNNCell。在此过程中,我们还将指定指示 nn.scan 广播 params 集合(所有步骤共享相同的参数)并且不拆分 params rng 流(以便所有步骤使用相同的参数进行初始化),最后,我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出也沿着第二个轴堆叠起来。然后,我们将立即使用此临时类型来创建提升的 RNNCell 的实例,并使用它来创建 carry 并运行 __call__ 方法,该方法将在序列上进行 scan

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
                  in_axes=1, out_axes=1)(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)
    return y

总的来说,Flax 和 Haiku 之间提升的转换的主要区别在于,在 Haiku 中,提升的转换不会对状态进行操作,也就是说,Haiku 将以一种在转换内外保持相同形状的方式处理 paramsstate。在 Flax 中,提升的转换可以对变量集合和 rng 流进行操作,用户必须根据转换的语义定义每个转换如何处理不同的集合。

最后,让我们快速查看如何在 Haiku 和 Flax 中使用 RNN 模块。

def forward(x):
  return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)

y = model.apply(
  params,
  x=jax.numpy.ones((3, 12, 32)),
)
...


model = RNN(64)

variables = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
  {'params': params},
  x=jax.numpy.ones((3, 12, 32)),
)

与前面部分中的示例相比,唯一值得注意的变化是这次我们在 Haiku 中使用了 hk.without_apply_rng,因此我们不必将 rng 参数作为 None 传递给 apply 方法。

在层上进行扫描#

scan 的一个非常重要的应用是,迭代地在输入上应用一系列层,将每个层的输出作为下一个层的输入传递。这对于减少大型模型的编译时间非常有用。例如,我们将创建一个简单的 Block 模块,然后将其用在 MLP 模块中,该模块将应用 num_layers 次的 Block 模块。

在 Haiku 中,我们像往常一样定义 Block 模块,然后在 MLP 中,我们将在 stack_block 函数上使用 hk.experimental.layer_stack 来创建一个 Block 模块堆栈。在 Flax 中,Block 的定义略有不同,__call__ 将接收和返回一个第二个虚拟输入/输出,它们在两种情况下都将为 None。在 MLP 中,我们将像在前面的示例中一样使用 nn.scan,但通过设置 split_rngs={'params': True}variable_axes={'params': 0},我们告诉 nn.scan 为每个步骤创建不同的参数并沿着第一个轴切片 params 集合,从而有效地实现一个 Block 模块堆栈,就像在 Haiku 中一样。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers

  def __call__(self, x, training: bool):
    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

请注意,在 Flax 中,我们如何将 None 作为第二个参数传递给 ScanBlock 并忽略其第二个输出。这些表示每个步骤的输入/输出,但它们是 None,因为在这种情况下,我们没有任何输入/输出。

初始化每个模型与前面的示例相同。在这种情况下,我们将指定我们希望使用 5 个层,每个层具有 64 个特征。

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)

model = hk.transform(forward)

sample_x = jax.numpy.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...
...


model = MLP(64, num_layers=5)

sample_x = jax.numpy.ones((1, 64))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables['params']

当在层上使用 scan 时,你应该注意的是,所有层都融合成一个单独的层,其参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...) 开头,因为我们使用的是 5 个层。

...
{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}
...
FrozenDict({
    ScanBlock_0: {
        Dense_0: {
            bias: (5, 64),
            kernel: (5, 64, 64),
        },
    },
})

顶级 Haiku 函数与顶级 Flax 模块#

在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state} 来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶级“模块”编写为函数非常常见。

Flax 团队建议使用更以模块为中心的 approach,该 approach 使用 __call__ 来定义前向函数。相应的访问器将是 nn.module.paramnn.module.variable(有关集合的解释,请转到 处理状态)。

def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  output = x + multiplier * counter
  hk.set_state("counter", counter + 1)

  return output

model = hk.transform_with_state(forward)

params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
    multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
    output = x + multiplier * counter.value
    if not self.is_initializing():  # otherwise model.init() also increases it
      counter.value += 1
    return output

model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']