从 Flax Linen 到 NNX 的演变#
本指南演示了 Flax Linen 和 Flax NNX 模型之间的差异,提供并排示例代码以帮助您从 Flax Linen 迁移到 Flax NNX API。
本文档主要介绍如何将任意 Flax Linen 代码转换为 Flax NNX。如果您想“安全地”逐步转换您的代码库,请查看 通过 nnx.bridge 一起使用 Flax NNX 和 Linen 指南。
为了充分利用本指南,强烈建议您阅读 Flax NNX 基础知识文档,该文档涵盖了 nnx.Module
系统、Flax 转换以及带有示例的函数式 API。
基本 Module
定义#
Flax Linen 和 Flax NNX 都使用 Module
类作为表达神经网络库层的默认单元。在下面的示例中,您首先创建一个 Block
(通过子类化 Module
),它由一个具有 dropout 和 ReLU 激活函数的线性层组成;然后,在创建 Model
(也通过子类化 Module
)时,您将其用作子-Module
,Model
由 Block
和一个线性层组成。
Flax Linen 和 Flax NNX Module
对象之间有两个根本区别
无状态 vs. 有状态:
flax.linen.Module
(nn.Module
) 实例是无状态的 - 变量是从纯函数式Module.init()
调用返回并单独管理的。flax.nnx.Module
,但是,拥有其变量作为此 Python 对象的属性。惰性 vs. 迫切:
flax.linen.Module
仅在实际看到其输入(惰性)时才分配空间来创建变量。flax.nnx.Module
实例会在实例化时(在看到样本输入之前)创建变量(迫切)。Flax Linen 可以使用
@nn.compact
装饰器在单个方法中定义模型,并使用输入样本进行形状推断。Flax NNXModule
通常需要额外的形状信息才能在__init__
期间创建所有参数,并在__call__
方法中单独定义计算。
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
变量创建#
接下来,让我们讨论实例化模型和初始化其参数
要为 Flax Linen 模型生成模型参数,您需要使用
jax.random.key
(文档) 和模型应采用的一些样本输入来调用flax.linen.Module.init
(nn.Module.init
) 方法。这将产生一个嵌套的 JAX 数组(jax.Array
数据类型)字典,这些字典将被随身携带并单独维护。在 Flax NNX 中,模型参数在您实例化模型时自动初始化,变量(
nnx.Variable
对象)作为属性存储在nnx.Module
(或其子-Module
)中。您仍然需要为其提供一个 伪随机数生成器 (PRNG) 密钥,但该密钥将包装在nnx.Rngs
类中并存储在内部,并在需要时生成更多的 PRNG 密钥。
如果您想以无状态的、类似字典的方式访问 Flax NNX 模型参数以进行检查点保存或模型手术,请查看 Flax NNX 分割/合并 API(nnx.split
/ nnx.merge
)。
model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]
assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
训练步骤和编译#
现在,让我们继续编写训练步骤并使用 JAX 即时编译进行编译。以下是 Flax Linen 和 Flax NNX 方法之间的一些差异。
编译训练步骤
Flax Linen 使用
@jax.jit
- 一个 JAX 转换 - 来编译训练步骤。Flax NNX 使用
@nnx.jit
- 一个 Flax NNX 转换(几个与 JAX 转换行为类似的转换 API 之一,但也与 Flax NNX 对象配合良好)。因此,虽然jax.jit
仅接受纯无状态参数的函数,但nnx.jit
允许参数是有状态的 NNX 模块。这大大减少了训练步骤所需的代码行数。
获取梯度
同样,Flax Linen 使用
jax.grad
(用于自动微分的 JAX 转换)来返回梯度的原始字典。Flax NNX 使用
nnx.grad
(Flax NNX 转换)将 NNX 模块的梯度作为nnx.State
字典返回。如果您想将常规的jax.grad
与 Flax NNX 一起使用,则需要使用Flax NNX 分割/合并 API。
优化器
如果您已经在使用 Optax 优化器(如
optax.adamw
)(而不是此处显示的原始jax.tree.map
计算)与 Flax Linen,请查看 Flax NNX 基础知识指南中的nnx.Optimizer
示例,以获得一种更简洁的模型训练和更新方法。
每个训练步骤中的模型更新
Flax Linen 训练步骤需要返回一个参数的pytree作为下一步的输入。
Flax NNX 训练步骤不需要返回任何内容,因为
model
已经在nnx.jit
内就地更新了。此外,
nnx.Module
对象是有状态的,并且Module
会自动跟踪其中的几个内容,例如 PRNG 密钥和BatchNorm
统计信息。这就是为什么您不需要在每一步中显式传入 PRNG 密钥的原因。另请注意,您可以使用nnx.reseed
来重置其底层的 PRNG 状态。
Dropout 行为
在 Flax Linen 中,您需要显式定义并传入
training
参数来控制flax.linen.Dropout
(nn.Dropout
) 的行为,即其deterministic
标志。这意味着只有当training=True
时才会发生随机 dropout。在 Flax NNX 中,您可以调用
model.train()
(flax.nnx.Module.train()
) 来自动将nnx.Dropout
切换到训练模式。相反,您可以调用model.eval()
(flax.nnx.Module.eval()
) 来关闭训练模式。您可以在其 API 参考中了解更多关于nnx.Module.train
的作用。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(inputs)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
集合和变量类型#
Flax Linen 和 NNX API 之间的一个主要区别是它们如何将变量分组到不同的类别。Flax Linen 使用不同的集合,而 Flax NNX,由于所有变量都应该是顶级的 Python 属性,您可以使用不同的变量类型。
在 Flax NNX 中,您可以自由创建自己的变量类型作为 nnx.Variable
的子类。
对于所有内置的 Flax Linen 层和集合,Flax NNX 已经创建了相应的层和变量类型。例如
flax.linen.Dense
(nn.Dense
) 创建params
->nnx.Linear
创建 :class:nnx.Param<flax.nnx.Param>`。flax.linen.BatchNorm
(nn.BatchNorm
) 创建batch_stats
->nnx.BatchNorm
创建nnx.BatchStats
。flax.linen.Module.sow()
创建intermediates
->nnx.Module.sow()
创建nnx.Intermediaries
。在 Flax NNX 中,您还可以通过将其分配给
nnx.Module
属性来简单地获取中间变量 - 例如,self.sowed = nnx.Intermediates(x)
。这类似于 Flax Linen 的self.variable('intermediates' 'sowed', lambda: x)
。
class Block(nn.Module):
features: int
def setup(self):
self.dense = nn.Dense(self.features)
self.batchnorm = nn.BatchNorm(momentum=0.99)
self.count = self.variable('counter', 'count',
lambda: jnp.zeros((), jnp.int32))
@nn.compact
def __call__(self, x, training: bool):
x = self.dense(x)
x = self.batchnorm(x, use_running_average=not training)
self.count.value += 1
x = jax.nn.relu(x)
return x
x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape # (4, )
variables['counter']['count'] # 1
class Counter(nnx.Variable): pass
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
self.count = Counter(jnp.array(0))
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
self.count += 1
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
model.count # Counter(value=...)
如果您想从变量的 pytree 中提取某些数组
在 Flax Linen 中,您可以访问特定的字典路径。
在 Flax NNX 中,您可以使用
nnx.split
来区分 Flax NNX 中的类型。下面的代码是一个简单的示例,它按类型拆分变量 - 请查看 Flax NNX 过滤器指南,了解更复杂的过滤表达式。
params, batch_stats, counter = (
variables['params'], variables['batch_stats'], variables['counter'])
params.keys() # ['dense', 'batchnorm']
batch_stats.keys() # ['batchnorm']
counter.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
model, nnx.Param, nnx.BatchStat, Counter)
params.keys() # ['batchnorm', 'linear']
batch_stats.keys() # ['batchnorm']
count.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)
使用多种方法#
在本节中,您将学习如何在 Flax Linen 和 Flax NNX 中使用多种方法。例如,您将实现一个具有三种方法的自动编码器模型:encode
、decode
和 __call__
。
定义编码器和解码器层
在 Flax Linen 中,和之前一样,定义层时无需传入输入形状,因为
flax.linen.Module
参数将使用形状推断进行延迟初始化。在 Flax NNX 中,您必须传入输入形状,因为
nnx.Module
参数将进行主动初始化,而无需形状推断。
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
变量结构如下
# variables['params']
{
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
},
'encoder': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
调用 __call__
以外的方法
在 Flax Linen 中,您仍然需要使用
apply
API。在 Flax NNX 中,您可以直接调用该方法。
z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))
变换#
Flax Linen 和 Flax NNX 变换都提供自己的一组变换,它们以可以与 Module
对象一起使用的方式包装 JAX 变换。
Flax Linen 中的大多数变换,例如 grad
或 jit
,在 Flax NNX 中变化不大。但是,例如,如果您尝试对层进行 scan
,如下节所述,代码会有很大不同。
让我们从一个例子开始
首先,定义一个
RNNCell
Module
,它将包含 RNN 单步的逻辑。定义一个
initial_state
方法,该方法将用于初始化 RNN 的状态(又名carry
)。与jax.lax.scan
(API 文档) 一样,RNNCell.__call__
方法将是一个接受 carry 和输入,并返回新 carry 和输出的函数。在这种情况下,carry 和输出是相同的。
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,定义一个 RNN
Module
,它将包含整个 RNN 的逻辑。
在 Flax Linen 中
您将使用
flax.linen.scan
(nn.scan
) 来定义一个新的临时类型,该类型包装RNNCell
。在此过程中,您还将:1) 指示nn.scan
广播params
集合(所有步骤共享相同的参数)并且不拆分params
PRNG 流(以便所有步骤都使用相同的参数进行初始化);最后,2) 指定您希望 scan 在输入的第二个轴上运行并将输出也沿第二个轴堆叠。然后,您将立即使用此临时类型来创建“提升”的
RNNCell
的实例,并使用它来创建carry
,并运行__call__
方法,该方法将scan
整个序列。
在 Flax NNX 中
您将创建一个
scan
函数 (scan_fn
),它将使用__init__
中定义的RNNCell
来扫描序列,并显式设置in_axes=(nnx.Carry, None, 1)
。nnx.Carry
表示carry
参数将是 carry,None
表示cell
将广播到所有步骤,1
表示x
将在轴 1 上扫描。
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(
RNNCell, variable_broadcast='params',
split_rngs={'params': False}, in_axes=1, out_axes=1
)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))
y = model(x)
扫描层#
一般来说,Flax Linen 和 Flax NNX 的变换应该看起来相同。然而,Flax NNX 变换旨在更接近其较低级别的 JAX 对应项,因此我们在某些 Linen 提升变换中放弃了一些假设。这种扫描层的使用案例将是展示它的一个很好的例子。
扫描层是一种技术,您可以通过一系列 N 个重复的层来运行输入,将每一层的输出作为下一层的输入。这种模式可以显著减少大型模型的编译时间。在下面的示例中,您将在顶层 MLP
Module
中重复 Block
Module
5 次。
在 Flax Linen 中,您将
flax.linen.scan
(nn.scan
) 变换应用于Block
nn.Module
,以创建一个更大的ScanBlock
nn.Module
,其中包含 5 个Block
nn.Module
对象。它将在初始化时自动创建一个形状为(5, 64, 64)
的大型参数,并在每次调用时迭代每个(64, 64)
切片,总共 5 次,就像jax.lax.scan
(API 文档) 一样。仔细观察,在这个模型的逻辑中,实际上在初始化时不需要
jax.lax.scan
操作。那里发生的事情更像是jax.vmap
操作 —— 你得到一个接受(in_dim, out_dim)
的Block
子-Module
,然后你将其 “vmap”num_layers
次,以创建一个更大的数组。在 Flax NNX 中,你可以利用模型初始化和运行代码完全解耦的事实,而是使用
nnx.vmap
变换来初始化底层的Block
参数,并使用nnx.scan
变换来运行模型输入。
有关 Flax NNX 变换的更多信息,请查看 变换指南。
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
model = MLP(64, num_layers=5)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
Flax NNX 上面的示例中还有一些其他细节需要解释。
`@nnx.split_rngs` 装饰器:Flax NNX 变换完全不考虑 PRNG 状态,这使得它们的行为更像 JAX 变换,但与处理 PRNG 状态的 Flax Linen 变换不同。为了重新获得此功能,
nnx.split_rngs
装饰器允许你在将nnx.Rngs
传递给装饰函数之前对其进行拆分,然后在之后 “降低” 它们,以便可以在外部使用它们。这里,你拆分 PRNG 密钥是因为如果其内部的每个操作都需要自己的密钥,则
jax.vmap
和jax.lax.scan
需要一个 PRNG 密钥列表。因此,对于MLP
内的 5 层,你在进入 JAX 变换之前,从其参数中拆分并提供 5 个不同的 PRNG 密钥。请注意,实际上
create_block()
正是因为看到了 5 个 PRNG 密钥才意识到它需要创建 5 层,因为in_axes=(0,)
表示vmap
将查看第一个参数的第一个维度,以了解它将映射的大小。对于
forward()
也是如此,它查看第一个参数(即model
)中的变量,以找出需要扫描多少次。nnx.split_rngs
在这里实际上拆分了model
内的 PRNG 状态。(如果Block
Module
没有 dropout,则你不需要nnx.split_rngs
行,因为它不会消耗任何 PRNG 密钥。)
为什么 Flax NNX 中的 Block Module 不需要获取和返回额外的虚拟值:这是
jax.lax.scan
的一个要求 (API 文档。Flax NNX 简化了这一点,因此,如果将out_axes
设置为nnx.Carry
而不是默认的(nnx.Carry, 0)
,你现在可以选择忽略第二个输出。这是 Flax NNX 变换与 JAX 变换 API 不同的少数情况之一。
Flax NNX 上面的示例中有更多的代码行,但它们更精确地表达了每次发生的事情。由于 Flax NNX 变换变得更接近 JAX 变换 API,因此建议在使用其 Flax NNX 等效项之前,对底层的 JAX 变换 有很好的理解。
现在检查两边的变量 pytree。
# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
在 Flax NNX 中使用 TrainState
#
Flax Linen 有一个方便的 TrainState
数据类来捆绑模型、参数和优化器。在 Flax NNX 中,这并不是真正必要的。在本节中,你将学习如何围绕 TrainState
构建你的 Flax NNX 代码,以满足任何向后兼容性需求。
在 Flax NNX 中
你必须首先在模型上调用
nnx.split
,以获得单独的nnx.GraphDef
和nnx.State
对象。你还需要子类化
TrainState
以添加其他变量的字段。然后,你可以传入
nnx.GraphDef.apply
作为apply
函数,nnx.State
作为参数和其他变量,以及一个优化器作为TrainState
构造函数的参数。
请注意,nnx.GraphDef.apply
将接受 nnx.State
对象作为参数,并返回一个可调用函数。可以对输入调用此函数以输出模型的 logits,以及更新后的 nnx.GraphDef
和 nnx.State
对象。请注意下面使用了 @jax.jit
,因为你没有将 Flax NNX 模块传递到 train_step
中。
from flax.training import train_state
sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(key, state, inputs, labels):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
inputs, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state
from flax.training import train_state
model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
class TrainState(train_state.TrainState):
other_variables: nnx.State
state = TrainState.create(
apply_fn=graphdef.apply,
params=params,
other_variables=other_variables,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, other_variables):
logits, (graphdef, new_state) = state.apply_fn(
params,
other_variables
)(inputs) # <== inputs
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params, state.other_variables)
state = state.apply_gradients(grads=grads)
return state