从 Flax Linen 到 NNX 的演变#

本指南演示了 Flax Linen 和 Flax NNX 模型之间的差异,提供并排示例代码以帮助您从 Flax Linen 迁移到 Flax NNX API。

本文档主要介绍如何将任意 Flax Linen 代码转换为 Flax NNX。如果您想“安全地”逐步转换您的代码库,请查看 通过 nnx.bridge 一起使用 Flax NNX 和 Linen 指南。

为了充分利用本指南,强烈建议您阅读 Flax NNX 基础知识文档,该文档涵盖了 nnx.Module 系统、Flax 转换以及带有示例的函数式 API

基本 Module 定义#

Flax Linen 和 Flax NNX 都使用 Module 类作为表达神经网络库层的默认单元。在下面的示例中,您首先创建一个 Block(通过子类化 Module),它由一个具有 dropout 和 ReLU 激活函数的线性层组成;然后,在创建 Model(也通过子类化 Module)时,您将其用作子-ModuleModelBlock 和一个线性层组成。

Flax Linen 和 Flax NNX Module 对象之间有两个根本区别

  • 无状态 vs. 有状态flax.linen.Module (nn.Module) 实例是无状态的 - 变量是从纯函数式 Module.init() 调用返回并单独管理的。flax.nnx.Module,但是,拥有其变量作为此 Python 对象的属性。

  • 惰性 vs. 迫切flax.linen.Module 仅在实际看到其输入(惰性)时才分配空间来创建变量。flax.nnx.Module 实例会在实例化时(在看到样本输入之前)创建变量(迫切)。

  • Flax Linen 可以使用 @nn.compact 装饰器在单个方法中定义模型,并使用输入样本进行形状推断。Flax NNX Module 通常需要额外的形状信息才能在 __init__ 期间创建所有参数,并在 __call__ 方法中单独定义计算。

import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

变量创建#

接下来,让我们讨论实例化模型和初始化其参数

  • 要为 Flax Linen 模型生成模型参数,您需要使用 jax.random.key (文档) 和模型应采用的一些样本输入来调用 flax.linen.Module.init (nn.Module.init) 方法。这将产生一个嵌套的 JAX 数组jax.Array 数据类型)字典,这些字典将被随身携带并单独维护。

  • 在 Flax NNX 中,模型参数在您实例化模型时自动初始化,变量(nnx.Variable 对象)作为属性存储在 nnx.Module(或其子-Module)中。您仍然需要为其提供一个 伪随机数生成器 (PRNG) 密钥,但该密钥将包装在 nnx.Rngs 类中并存储在内部,并在需要时生成更多的 PRNG 密钥。

如果您想以无状态的、类似字典的方式访问 Flax NNX 模型参数以进行检查点保存或模型手术,请查看 Flax NNX 分割/合并 APInnx.split / nnx.merge)。

model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

训练步骤和编译#

现在,让我们继续编写训练步骤并使用 JAX 即时编译进行编译。以下是 Flax Linen 和 Flax NNX 方法之间的一些差异。

编译训练步骤

  • Flax Linen 使用 @jax.jit - 一个 JAX 转换 - 来编译训练步骤。

  • Flax NNX 使用 @nnx.jit - 一个 Flax NNX 转换(几个与 JAX 转换行为类似的转换 API 之一,但也与 Flax NNX 对象配合良好)。因此,虽然 jax.jit 仅接受纯无状态参数的函数,但 nnx.jit 允许参数是有状态的 NNX 模块。这大大减少了训练步骤所需的代码行数。

获取梯度

  • 同样,Flax Linen 使用 jax.grad(用于自动微分的 JAX 转换)来返回梯度的原始字典。

  • Flax NNX 使用 nnx.grad(Flax NNX 转换)将 NNX 模块的梯度作为 nnx.State 字典返回。如果您想将常规的 jax.grad 与 Flax NNX 一起使用,则需要使用Flax NNX 分割/合并 API

优化器

  • 如果您已经在使用 Optax 优化器(如 optax.adamw)(而不是此处显示的原始 jax.tree.map 计算)与 Flax Linen,请查看 Flax NNX 基础知识指南中的 nnx.Optimizer 示例,以获得一种更简洁的模型训练和更新方法。

每个训练步骤中的模型更新

  • Flax Linen 训练步骤需要返回一个参数的pytree作为下一步的输入。

  • Flax NNX 训练步骤不需要返回任何内容,因为 model 已经在 nnx.jit 内就地更新了。

  • 此外,nnx.Module 对象是有状态的,并且 Module 会自动跟踪其中的几个内容,例如 PRNG 密钥和 BatchNorm 统计信息。这就是为什么您不需要在每一步中显式传入 PRNG 密钥的原因。另请注意,您可以使用 nnx.reseed 来重置其底层的 PRNG 状态。

Dropout 行为

  • 在 Flax Linen 中,您需要显式定义并传入 training 参数来控制 flax.linen.Dropout (nn.Dropout) 的行为,即其 deterministic 标志。这意味着只有当 training=True 时才会发生随机 dropout。

  • 在 Flax NNX 中,您可以调用 model.train() (flax.nnx.Module.train()) 来自动将 nnx.Dropout 切换到训练模式。相反,您可以调用 model.eval() (flax.nnx.Module.eval()) 来关闭训练模式。您可以在其 API 参考中了解更多关于 nnx.Module.train 的作用。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)

  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)




    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.GraphState.merge(params, rest))

集合和变量类型#

Flax Linen 和 NNX API 之间的一个主要区别是它们如何将变量分组到不同的类别。Flax Linen 使用不同的集合,而 Flax NNX,由于所有变量都应该是顶级的 Python 属性,您可以使用不同的变量类型。

在 Flax NNX 中,您可以自由创建自己的变量类型作为 nnx.Variable 的子类。

对于所有内置的 Flax Linen 层和集合,Flax NNX 已经创建了相应的层和变量类型。例如

  • flax.linen.Dense (nn.Dense) 创建 params -> nnx.Linear 创建 :class:nnx.Param<flax.nnx.Param>`。

  • flax.linen.BatchNorm (nn.BatchNorm) 创建 batch_stats -> nnx.BatchNorm 创建 nnx.BatchStats

  • flax.linen.Module.sow() 创建 intermediates -> nnx.Module.sow() 创建 nnx.Intermediaries

  • 在 Flax NNX 中,您还可以通过将其分配给 nnx.Module 属性来简单地获取中间变量 - 例如,self.sowed = nnx.Intermediates(x)。这类似于 Flax Linen 的 self.variable('intermediates' 'sowed', lambda: x)

class Block(nn.Module):
  features: int
  def setup(self):
    self.dense = nn.Dense(self.features)
    self.batchnorm = nn.BatchNorm(momentum=0.99)
    self.count = self.variable('counter', 'count',
                                lambda: jnp.zeros((), jnp.int32))


  @nn.compact
  def __call__(self, x, training: bool):
    x = self.dense(x)
    x = self.batchnorm(x, use_running_average=not training)
    self.count.value += 1
    x = jax.nn.relu(x)
    return x

x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape         # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape  # (4, )
variables['counter']['count']                        # 1
class Counter(nnx.Variable): pass

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )
    self.count = Counter(jnp.array(0))

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    self.count += 1
    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)
model.count           # Counter(value=...)

如果您想从变量的 pytree 中提取某些数组

  • 在 Flax Linen 中,您可以访问特定的字典路径。

  • 在 Flax NNX 中,您可以使用 nnx.split 来区分 Flax NNX 中的类型。下面的代码是一个简单的示例,它按类型拆分变量 - 请查看 Flax NNX 过滤器指南,了解更复杂的过滤表达式。

params, batch_stats, counter = (
  variables['params'], variables['batch_stats'], variables['counter'])
params.keys()       # ['dense', 'batchnorm']
batch_stats.keys()  # ['batchnorm']
counter.keys()      # ['count']

# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
  model, nnx.Param, nnx.BatchStat, Counter)
params.keys()       # ['batchnorm', 'linear']
batch_stats.keys()  # ['batchnorm']
count.keys()        # ['count']

# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)

使用多种方法#

在本节中,您将学习如何在 Flax Linen 和 Flax NNX 中使用多种方法。例如,您将实现一个具有三种方法的自动编码器模型:encodedecode__call__

定义编码器和解码器层

  • 在 Flax Linen 中,和之前一样,定义层时无需传入输入形状,因为 flax.linen.Module 参数将使用形状推断进行延迟初始化。

  • 在 Flax NNX 中,您必须传入输入形状,因为 nnx.Module 参数将进行主动初始化,而无需形状推断。

class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):



  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))

变量结构如下

# variables['params']
{
  decoder: {
      bias: (784,),
      kernel: (256, 784),
  },
  encoder: {
      bias: (256,),
      kernel: (784, 256),
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

调用 __call__ 以外的方法

  • 在 Flax Linen 中,您仍然需要使用 apply API。

  • 在 Flax NNX 中,您可以直接调用该方法。

z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))

变换#

Flax Linen 和 Flax NNX 变换都提供自己的一组变换,它们以可以与 Module 对象一起使用的方式包装 JAX 变换

Flax Linen 中的大多数变换,例如 gradjit,在 Flax NNX 中变化不大。但是,例如,如果您尝试对层进行 scan,如下节所述,代码会有很大不同。

让我们从一个例子开始

  • 首先,定义一个 RNNCell Module,它将包含 RNN 单步的逻辑。

  • 定义一个 initial_state 方法,该方法将用于初始化 RNN 的状态(又名 carry)。与 jax.lax.scan (API 文档) 一样,RNNCell.__call__ 方法将是一个接受 carry 和输入,并返回新 carry 和输出的函数。在这种情况下,carry 和输出是相同的。

class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,定义一个 RNN Module,它将包含整个 RNN 的逻辑。

在 Flax Linen 中

  • 您将使用 flax.linen.scan (nn.scan) 来定义一个新的临时类型,该类型包装 RNNCell。在此过程中,您还将:1) 指示 nn.scan 广播 params 集合(所有步骤共享相同的参数)并且不拆分 params PRNG 流(以便所有步骤都使用相同的参数进行初始化);最后,2) 指定您希望 scan 在输入的第二个轴上运行并将输出也沿第二个轴堆叠。

  • 然后,您将立即使用此临时类型来创建“提升”的 RNNCell 的实例,并使用它来创建 carry,并运行 __call__ 方法,该方法将 scan 整个序列。

在 Flax NNX 中

  • 您将创建一个 scan 函数 (scan_fn),它将使用 __init__ 中定义的 RNNCell 来扫描序列,并显式设置 in_axes=(nnx.Carry, None, 1)nnx.Carry 表示 carry 参数将是 carry,None 表示 cell 将广播到所有步骤,1 表示 x 将在轴 1 上扫描。

class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))

y = model(x)

扫描层#

一般来说,Flax Linen 和 Flax NNX 的变换应该看起来相同。然而,Flax NNX 变换旨在更接近其较低级别的 JAX 对应项,因此我们在某些 Linen 提升变换中放弃了一些假设。这种扫描层的使用案例将是展示它的一个很好的例子。

扫描层是一种技术,您可以通过一系列 N 个重复的层来运行输入,将每一层的输出作为下一层的输入。这种模式可以显著减少大型模型的编译时间。在下面的示例中,您将在顶层 MLP Module 中重复 Block Module 5 次。

  • 在 Flax Linen 中,您将 flax.linen.scan (nn.scan) 变换应用于 Block nn.Module,以创建一个更大的 ScanBlock nn.Module,其中包含 5 个 Block nn.Module 对象。它将在初始化时自动创建一个形状为 (5, 64, 64) 的大型参数,并在每次调用时迭代每个 (64, 64) 切片,总共 5 次,就像 jax.lax.scan (API 文档) 一样。

  • 仔细观察,在这个模型的逻辑中,实际上在初始化时不需要 jax.lax.scan 操作。那里发生的事情更像是 jax.vmap 操作 —— 你得到一个接受 (in_dim, out_dim)Block 子-Module,然后你将其 “vmap” num_layers 次,以创建一个更大的数组。

  • 在 Flax NNX 中,你可以利用模型初始化和运行代码完全解耦的事实,而是使用 nnx.vmap 变换来初始化底层的 Block 参数,并使用 nnx.scan 变换来运行模型输入。

有关 Flax NNX 变换的更多信息,请查看 变换指南

class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int




  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

model = MLP(64, num_layers=5)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

Flax NNX 上面的示例中还有一些其他细节需要解释。

  • `@nnx.split_rngs` 装饰器:Flax NNX 变换完全不考虑 PRNG 状态,这使得它们的行为更像 JAX 变换,但与处理 PRNG 状态的 Flax Linen 变换不同。为了重新获得此功能,nnx.split_rngs 装饰器允许你在将 nnx.Rngs 传递给装饰函数之前对其进行拆分,然后在之后 “降低” 它们,以便可以在外部使用它们。

    • 这里,你拆分 PRNG 密钥是因为如果其内部的每个操作都需要自己的密钥,则 jax.vmapjax.lax.scan 需要一个 PRNG 密钥列表。因此,对于 MLP 内的 5 层,你在进入 JAX 变换之前,从其参数中拆分并提供 5 个不同的 PRNG 密钥。

    • 请注意,实际上 create_block() 正是因为看到了 5 个 PRNG 密钥才意识到它需要创建 5 层,因为 in_axes=(0,) 表示 vmap 将查看第一个参数的第一个维度,以了解它将映射的大小。

    • 对于 forward() 也是如此,它查看第一个参数(即 model)中的变量,以找出需要扫描多少次。nnx.split_rngs 在这里实际上拆分了 model 内的 PRNG 状态。(如果 Block Module 没有 dropout,则你不需要 nnx.split_rngs 行,因为它不会消耗任何 PRNG 密钥。)

  • 为什么 Flax NNX 中的 Block Module 不需要获取和返回额外的虚拟值:这是 jax.lax.scan 的一个要求 (API 文档。Flax NNX 简化了这一点,因此,如果将 out_axes 设置为 nnx.Carry 而不是默认的 (nnx.Carry, 0),你现在可以选择忽略第二个输出。

    • 这是 Flax NNX 变换与 JAX 变换 API 不同的少数情况之一。

Flax NNX 上面的示例中有更多的代码行,但它们更精确地表达了每次发生的事情。由于 Flax NNX 变换变得更接近 JAX 变换 API,因此建议在使用其 Flax NNX 等效项之前,对底层的 JAX 变换 有很好的理解。

现在检查两边的变量 pytree。

# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
})

在 Flax NNX 中使用 TrainState#

Flax Linen 有一个方便的 TrainState 数据类来捆绑模型、参数和优化器。在 Flax NNX 中,这并不是真正必要的。在本节中,你将学习如何围绕 TrainState 构建你的 Flax NNX 代码,以满足任何向后兼容性需求。

在 Flax NNX 中

  • 你必须首先在模型上调用 nnx.split,以获得单独的 nnx.GraphDefnnx.State 对象。

  • 你可以传入 nnx.Param 以将所有可训练参数过滤到单个 nnx.State 中,并为其余变量传入 ...

  • 你还需要子类化 TrainState 以添加其他变量的字段。

  • 然后,你可以传入 nnx.GraphDef.apply 作为 apply 函数,nnx.State 作为参数和其他变量,以及一个优化器作为 TrainState 构造函数的参数。

请注意,nnx.GraphDef.apply 将接受 nnx.State 对象作为参数,并返回一个可调用函数。可以对输入调用此函数以输出模型的 logits,以及更新后的 nnx.GraphDefnnx.State 对象。请注意下面使用了 @jax.jit,因为你没有将 Flax NNX 模块传递到 train_step 中。

from flax.training import train_state

sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']




state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,

  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)


  state = state.apply_gradients(grads=grads)

  return state
from flax.training import train_state

model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(
      params,
      other_variables

    )(inputs) # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)


  state = state.apply_gradients(grads=grads)

  return state