从 Haiku 迁移到 Flax#

本指南演示了 Haiku 和 Flax NNX 模型之间的差异,并提供了并排的示例代码,以帮助您从 Haiku 迁移到 Flax NNX API。

如果您是 Flax NNX 的新手,请务必熟悉 Flax NNX 基础知识,其中涵盖了 nnx.Module 系统、Flax 转换功能性 API 以及示例。

让我们从一些导入开始。

基本模块定义#

Haiku 和 Flax 都使用 Module 类作为表达神经网络库层的默认单元。例如,要创建一个具有 dropout 和 ReLU 激活函数的单层网络,您需要

  • 首先,创建一个 Block(通过子类化 Module),它由一个带有 dropout 和 ReLU 激活函数的线性层组成。

  • 然后,在创建 Model(也通过子类化 Module)时,将 Block 用作子 Module,该 ModelBlock 和一个线性层组成。

Haiku 和 Flax Module 对象之间有两个根本区别

  • 无状态 vs. 有状态:

    • haiku.Module 实例是无状态的。这意味着,变量是从纯功能性 Module.init() 调用返回并单独管理的。

    • 然而,flax.nnx.Module 将其变量作为此 Python 对象的属性拥有。

  • 延迟 vs. 立即:

    • haiku.Module 仅在用户调用模型时实际看到输入时才分配空间来创建变量(延迟)。

    • flax.nnx.Module 实例在实例化时(在看到样本输入之前)创建变量(立即)。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)


  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

变量创建#

本节介绍实例化模型和初始化其参数。

  • 要为 Haiku 模型生成模型参数,您需要将其放入前向函数中,并使用 haiku.transform 使其成为纯功能性。这会导致一个 JAX 数组jax.Array 数据类型)的嵌套字典,需要单独携带和维护。

  • 在 Flax NNX 中,当您实例化模型时,会自动初始化模型参数,并且变量(nnx.Variable 对象)作为属性存储在 nnx.Module(或其子模块)中。您仍然需要为它提供一个 伪随机数生成器 (PRNG) 密钥,但该密钥将包装在 nnx.Rngs 类中并存储在其中,并在需要时生成更多 PRNG 密钥。

如果您想以无状态的、类似字典的方式访问 Flax 模型参数以进行检查点保存或模型手术,请查看 Flax NNX 分割/合并 APInnx.split / nnx.merge)。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)


assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...


model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

训练步骤和编译#

本节介绍如何编写训练步骤并使用 JAX 即时编译对其进行编译。

编译训练步骤时

  • Haiku 使用 @jax.jit - 一个 JAX 转换 - 来编译一个纯功能性训练步骤。

  • Flax NNX 使用 @nnx.jit - 一个 Flax NNX 转换(几个转换 API 之一,其行为类似于 JAX 转换,但也 可以很好地与 Flax 对象一起工作)。虽然 jax.jit 只接受具有纯无状态参数的函数,但 flax.nnx.jit 允许参数为有状态的模块。这大大减少了训练步骤所需的行数。

在求梯度时

  • 类似地,Haiku 使用 jax.grad(用于 自动微分的 JAX 转换)来返回梯度的原始字典。

  • 同时,Flax NNX 使用 flax.nnx.grad(Flax NNX 转换)来返回 Flax NNX 模块的梯度,作为 flax.nnx.State 字典。如果您想将常规 jax.grad 与 Flax NNX 一起使用,则需要使用 分割/合并 API

对于优化器

  • 如果您已经在使用 Optax 优化器(如 optax.adamw)(而不是此处显示的原始 jax.tree.map 计算)与 Haiku,请查看 Flax 基础知识指南中的 flax.nnx.Optimizer 示例,了解一种更简洁的训练和更新模型的方法。

每次训练步骤期间的模型更新

  • Haiku 训练步骤需要返回一个 JAX 树的参数作为下一步的输入。

  • Flax NNX 训练步骤不需要返回任何内容,因为 model 已经在 nnx.jit 内就地更新。

  • 此外,nnx.Module 对象是有状态的,并且 Module 会自动跟踪其中的几件事,例如 PRNG 密钥和 flax.nnx.BatchNorm 统计信息。这就是为什么您无需在每一步都显式传递 PRNG 密钥的原因。另请注意,您可以使用 flax.nnx.reseed 来重置其底层的 PRNG 状态。

dropout 行为

  • 在 Haiku 中,您需要显式定义并传入 training 参数来切换 haiku.dropout,并确保只有在 training=True 时才会发生随机 dropout。

  • 在 Flax NNX 中,你可以调用 model.train() (flax.nnx.Module.train()) 来自动将 flax.nnx.Dropout 切换到训练模式。相反,你可以调用 model.eval() (flax.nnx.Module.eval()) 来关闭训练模式。你可以在其 API 参考 中了解更多关于 flax.nnx.Module.train 的信息。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params, key,
      inputs, training=True # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)


  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
model.train() # set deterministic=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(

      inputs, # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.GraphState.merge(params, rest))

处理非参数状态#

Haiku 区分了可训练参数和模型跟踪的所有其他数据(“状态”)。例如,批量归一化中使用的批次统计信息被视为状态。具有状态的模型需要使用 hk.transform_with_state 进行转换,以便其 .init() 返回参数和状态。

在 Flax 中,没有如此严格的区分 - 它们都是 nnx.Variable 的子类,并被模块视为其属性。参数是名为 nnx.Param 的子类的实例,而批次统计信息可以是另一个名为 nnx.BatchStat 的子类。你可以使用 nnx.split 快速提取特定变量类型的所有数据。

让我们通过采用上面的 Block 定义,但将 dropout 替换为 BatchNorm 来看看这个例子。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features



  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x

def forward(x, training: bool):
  return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)

sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)


    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)

Flax 会考虑可训练参数和其他数据之间的差异。nnx.grad 只会对 nnx.Param 变量求梯度,从而自动跳过 batchnorm 数组。因此,对于使用此模型的 Flax NNX,训练步骤将看起来相同。

使用多个方法#

在本节中,你将学习如何在 Haiku 和 Flax 中使用多个方法。例如,你将实现一个具有三个方法的自动编码器模型:encodedecode__call__

在 Haiku 中,你需要使用 hk.multi_transform 显式定义如何初始化模型以及它可以调用哪些方法(这里的 encodedecode)。请注意,你仍然需要定义一个 __call__,它会激活两个层,以便延迟初始化所有模型参数。

在 Flax 中,它更简单,因为你在 __init__ 中初始化了参数,并且可以直接使用 nnx.Module 方法 encodedecode

class AutoEncoder(hk.Module):

  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):

  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):

    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)











model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...

参数结构如下

...


{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

要调用这些自定义方法

  • 在 Haiku 中,你需要解耦 .apply 函数以在调用它之前提取你的方法。

  • 在 Flax 中,你可以直接调用该方法。

encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))

转换#

Haiku 和 Flax 转换 都提供了自己的转换集,这些转换以一种可以与 Module 对象一起使用的方式包装了 JAX 转换

有关 Flax 转换的更多信息,请查看 转换指南

让我们从一个例子开始

  • 首先,定义一个 RNNCell Module,其中将包含 RNN 单步的逻辑。

  • 定义一个 initial_state 方法,该方法将用于初始化 RNN 的状态(又名 carry)。与 jax.lax.scan (API 文档) 一样,RNNCell.__call__ 方法将是一个接收 carry 和输入并返回新的 carry 和输出的函数。在这种情况下,carry 和输出是相同的。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,我们将定义一个 RNN 模块,其中将包含整个 RNN 的逻辑。在这两种情况下,我们都使用库的 scan 调用来运行输入序列上的 RNNCell

唯一的区别是 Flax nnx.scan 允许你指定在参数 in_axesout_axes 中重复的轴,这将转发到底层的 `jax.lax.scan<https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html>`__,而在 Haiku 中,你需要显式地转置输入和输出。

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(
      cell, carry,
      jnp.swapaxes(x, 1, 0)
    )
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

扫描层#

大多数 Haiku 转换应与 Flax 类似,因为它们都包装了它们的 JAX 对等物,但扫描层用例是一个例外。

扫描层是一种技术,你可以在其中通过一系列 N 个重复层运行输入,将每个层的输出作为下一层的输入传递。这种模式可以显著减少大型模型的编译时间。在下面的示例中,你将在顶层的 MLP Module 中重复 Block Module 5 次。

在 Haiku 中,我们像往常一样定义 Block 模块,然后在 MLP 内部,我们将使用 hk.experimental.layer_stackstack_block 函数之上来创建 Block 模块的堆栈。相同的代码将在初始化时创建 5 层参数,并在调用时通过它们运行输入。

在 Flax 中,模型初始化和调用代码是完全解耦的,因此我们使用 nnx.vmap 转换来初始化底层 Block 参数,并使用 nnx.scan 转换来运行模型输入。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers





  def __call__(self, x, training: bool):

    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)

sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)



model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

在上面的 Flax 示例中还有一些其他细节需要解释

  • `@nnx.split_rngs` 装饰器: Flax 转换与其 JAX 对等物一样,完全不了解 PRNG 状态,并且依赖于 PRNG 密钥的输入。nnx.split_rngs 装饰器允许你在将 nnx.Rngs 传递给装饰函数之前拆分它们,并在之后“降低”它们,以便它们可以在外部使用。

    • 在这里,你拆分了 PRNG 密钥,因为如果其每个内部操作都需要自己的密钥,则 jax.vmapjax.lax.scan 需要 PRNG 密钥列表。因此,对于 MLP 内的 5 层,你在转到 JAX 转换之前拆分并从其参数提供 5 个不同的 PRNG 密钥。

    • 请注意,实际上 create_block() 知道它需要创建 5 层正是因为它看到了 5 个 PRNG 密钥,因为 in_axes=(0,) 指示 vmap 将查看第一个参数的第一个维度以了解它将映射的大小。

    • 对于 forward() 也是如此,它会查看第一个参数(又名 model)中的变量,以找出它需要扫描多少次。nnx.split_rngs 这里实际上拆分了 model 内部的 PRNG 状态。(如果 Block Module 没有 dropout,则你不需要 nnx.split_rngs 行,因为它无论如何都不会消耗任何 PRNG 密钥。)

  • 为什么 Flax 中的 Block 模块不需要接收和返回额外的虚拟值: jax.lax.scan (API 文档) 要求其函数返回两个输入 - 进位和堆叠输出。在这种情况下,我们没有使用后者。Flax 对此进行了简化,因此如果你将 out_axes 设置为 nnx.Carry 而不是默认的 (nnx.Carry, 0),现在可以选择忽略第二个输出。

    • 这是 Flax NNX 转换与 JAX 转换 API 不同的少数情况之一。

上面的 Flax 示例中有更多的代码行,但它们更精确地表达了每次发生的情况。由于 Flax 转换越来越接近 JAX 转换 API,因此建议在使用其 Flax NNX 等效项 之前,先很好地理解底层的 JAX 转换

现在检查两侧的变量 pytree

...


{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}



...
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
})

顶层 Haiku 函数 vs 顶层 Flax 模块#

在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state} 来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”编写为函数是非常常见的做法。

Flax 团队推荐一种更以模块为中心的方法,该方法使用 __call__ 来定义前向函数。在 Flax 模块中,可以使用常规 Python 类语义,像正常方式一样设置和访问参数和变量。

...


def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter(
    'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
  )

  output = x + multiplier * counter

  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
  pass

class FooModule(nnx.Module):

  def __init__(self, rngs):
    self.counter = Counter(jnp.ones((), jnp.int32))
    self.multiplier = nnx.Param(
      nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
    )
  def __call__(self, x):
    output = x + self.multiplier * self.counter.value

    self.counter.value += 1
    return output

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)