亚麻基础#
本笔记本将引导您完成以下工作流程
从亚麻内置层或第三方模型实例化模型。
初始化模型的参数并手动编写训练。
使用亚麻提供的优化器来简化训练。
参数和其他对象的序列化。
创建您自己的模型并管理状态。
设置我们的环境#
这里我们提供了设置笔记本环境所需的代码。
# Install the latest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pythonlang.cn/warnings/venv
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pythonlang.cn/warnings/venv
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
使用亚麻进行线性回归#
在之前的JAX for the impatient笔记本中,我们完成了线性回归示例。我们知道,线性回归也可以写成一个单一的密集神经网络层,我们将在下面展示它,以便我们比较如何完成它。
密集层是一个具有内核参数\(W\in\mathcal{M}_{m,n}(\mathbb{R})\)的层,其中\(m\)是模型输出的特征数,而\(n\)是输入的维数,以及一个偏差参数\(b\in\mathbb{R}^m\)。密集层从输入\(x\in\mathbb{R}^n\)返回\(Wx+b\)。
此密集层已在 flax.linen
模块中由亚麻提供(这里导入为 nn
)。
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)
层(以及一般模型,我们将在后面使用该词)是 linen.Module
类的子类。
模型参数和初始化#
参数不会与模型本身一起存储。您需要使用 PRNGKey 和虚拟输入数据调用 init
函数来初始化参数。
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
FrozenDict({
params: {
bias: (5,),
kernel: (10, 5),
},
})
注意:JAX 和亚麻与 NumPy 一样,是基于行的系统,这意味着向量表示为行向量,而不是列向量。这可以在此处的内核形状中看到。
结果是我们预期的:大小正确的偏差和内核参数。幕后
虚拟输入数据
x
用于触发形状推断:我们只声明了模型输出中所需的特征数量,而不是输入的大小。亚麻会自动找出内核的正确大小。随机 PRNG 密钥用于触发初始化函数(这些函数具有由模块在此处提供的默认值)。
初始化函数被调用以生成模型将使用的初始参数集。这些函数接受
(PRNG Key, shape, dtype)
作为参数,并返回形状为shape
的数组。init 函数返回已初始化的参数集(您还可以使用与
init
相同的语法,使用init_with_output
方法而不是init
来获取虚拟输入上正向传递的输出)。
要使用给定参数集(这些参数从未与模型一起存储)对模型进行正向传递,我们只需使用 apply
方法,向它提供要使用的参数以及输入
model.apply(params, x)
DeviceArray([-0.7358944, 1.3583755, -0.7976872, 0.8168598, 0.6297793], dtype=float32)
梯度下降#
如果您直接跳到这里而没有阅读 JAX 部分,那么我们将使用以下线性回归公式:从一组数据点\(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\)开始,我们尝试找到一组参数\(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\),使得函数\(f_{W,b}(x)=Wx+b\)最小化均方误差
在这里,我们看到元组\((W,b)\)与 Dense 层的参数匹配。我们将使用这些参数执行梯度下降。让我们首先生成我们将使用的假数据。数据与 JAX 部分中的 Pytree 线性回归示例完全相同。
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5
# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)
我们使用 jax.value_and_grad()
复制了在 JAX Pytree 线性回归示例中使用的相同训练循环,但在这里我们可以使用 model.apply()
而不是必须定义我们自己的前馈函数(predict_pytree()
在 JAX 示例中)。
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
# Define the squared loss for a single pair (x,y)
def squared_error(x, y):
pred = model.apply(params, x)
return jnp.inner(y-pred, y-pred) / 2.0
# Vectorize the previous to compute the average of the loss on all samples.
return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
最后执行梯度下降。
learning_rate = 0.3 # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g, params, grads)
return params
for i in range(101):
# Perform one gradient update.
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
params = update_params(params, learning_rate, grads)
if i % 10 == 0:
print(f'Loss step {i}: ', loss_val)
Loss for "true" W,b: 0.023639778
Loss step 0: 38.094772
Loss step 10: 0.44692168
Loss step 20: 0.10053458
Loss step 30: 0.035822745
Loss step 40: 0.018846875
Loss step 50: 0.013864839
Loss step 60: 0.012312559
Loss step 70: 0.011812928
Loss step 80: 0.011649306
Loss step 90: 0.011595251
Loss step 100: 0.0115773035
使用 Optax 优化#
亚麻过去使用自己的 flax.optim
包进行优化,但随着 FLIP #1009 的发布,它被弃用,转而使用 Optax。
Optax 的基本用法很简单
选择一种优化方法(例如
optax.adam
)。从参数创建优化器状态(对于 Adam 优化器,此状态将包含 动量值)。
使用
jax.value_and_grad()
计算损失的梯度。在每次迭代中,调用 Optax
update
函数来更新内部优化器状态并创建对参数的更新。然后使用 Optax 的apply_updates
方法将更新添加到参数中。
请注意,Optax 可以做更多的事情:它旨在将简单的梯度变换组合成更复杂的变换,从而实现各种优化器。它还支持随着时间推移更改优化器超参数(“调度”)、将不同的更新应用于参数树的不同部分(“掩码”)等等。有关详细信息,请参阅 官方文档。
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
if i % 10 == 0:
print('Loss step {}: '.format(i), loss_val)
Loss step 0: 0.011576377
Loss step 10: 0.0115710115
Loss step 20: 0.011569244
Loss step 30: 0.011568661
Loss step 40: 0.011568454
Loss step 50: 0.011568379
Loss step 60: 0.011568358
Loss step 70: 0.01156836
Loss step 80: 0.01156835
Loss step 90: 0.011568353
Loss step 100: 0.011568348
序列化结果#
现在我们对训练结果感到满意,我们可能希望保存模型参数以便以后加载。亚麻提供了一个序列化包,使您可以做到这一点。
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
Dict output
{'params': {'bias': DeviceArray([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547], dtype=float32), 'kernel': DeviceArray([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285,
0.34720102],
[ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 ,
-0.10574618],
[-1.2009128 , 0.28837162, 1.4176372 , 0.12073109,
-1.3132601 ],
[-1.1944956 , -0.18993308, 0.03379077, 1.3165942 ,
0.07996067],
[ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774,
-2.239638 ],
[ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 ,
0.90352124],
[-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 ,
0.9286919 ],
[ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117,
-1.202457 ],
[ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 ,
-0.45790705],
[-0.64384323, 0.4564892 , -1.1331053 , -0.68556863,
0.17010891]], dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\x1d\x1d\xba\xbf\xc4\xad\x01\xc0\x81)\x05@\xdd.\x9c?\xa8\x17\x7f\xbf\xa6kernel\xc7\xd6\x01\x93\x92\n\x05\xa7float32\xc4\xc8\x84]\x81?\xf0\xb5B>`\xb59=z^m\xbfU\xc4\xb1>\x00\xb3\xdd?\xb8x}?\xc7F\x95?2(\x8d?t\x91\xd8\xbd\x83\xb7\x99\xbfr\xa5\x93>#u\xb5?\xdcA\xf7=\xe8\x18\xa8\xbf;\xe5\x98\xbf\xd1}B\xbe0h\n=)\x86\xa8?k\xc2\xa3=\xaaj\x10>\x91\xd8\xaf?\xa9y\xa8\xbfc\xb5\x08?;V\x0f\xc0Av\x10?ZHP?wD\xa3>\x022\t?+Mg?\xa0K\xc2\xbe\xb1\xd3\xde?)\x16\x8a?\x04\x13\x01\xbf\xc1\xbem?\xfdZx?Wn\xa8\xbf\x940\xac>\x925O?\x1c\xea\x99\xbf\x9e\x89\x82?\x07\xad\x1e\xbf\xe2\x87\x8a?\xdfU\xeb\xbf\xcbr\xea\xbe\xe9\xd2$\xbf\xf4\xb8\xe9>\x98\t\x91\xbfm\x81/\xbf\x081.>'
要加载模型,您需要使用模型参数结构的模板,例如您从模型初始化获得的模板。在这里,我们使用先前生成的 params
作为模板。请注意,这将产生一个新的变量结构,而不是就地修改。
通过模板强制执行结构的目的是避免用户在后续操作中出现问题,因此您需要先拥有生成参数结构的正确模型。
serialization.from_bytes(params, bytes_output)
FrozenDict({
params: {
bias: array([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547],
dtype=float32),
kernel: array([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285, 0.34720102],
[ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 , -0.10574618],
[-1.2009128 , 0.28837162, 1.4176372 , 0.12073109, -1.3132601 ],
[-1.1944956 , -0.18993308, 0.03379077, 1.3165942 , 0.07996067],
[ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774, -2.239638 ],
[ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 , 0.90352124],
[-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 , 0.9286919 ],
[ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117, -1.202457 ],
[ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 , -0.45790705],
[-0.64384323, 0.4564892 , -1.1331053 , -0.68556863, 0.17010891]],
dtype=float32),
},
})
定义您自己的模型#
亚麻允许您定义您自己的模型,这些模型应该比线性回归更复杂。在本节中,我们将向您展示如何构建简单的模型。为此,您需要创建基本 nn.Module
类的子类。
请记住,我们导入 linen as nn
,并且这仅适用于新的 linen API
模块基础#
模型的基本抽象是 nn.Module
类,亚麻中的每种类型的预定义层(如之前的 Dense
)都是 nn.Module
的子类。让我们看一下,并首先定义一个简单的自定义多层感知器,即一系列密集层,其间穿插着对非线性激活函数的调用。
class ExplicitMLP(nn.Module):
features: Sequence[int]
def setup(self):
# we automatically know what to do with lists, dicts of submodules
self.layers = [nn.Dense(feat) for feat in self.features]
# for single submodules, we would just write:
# self.layer1 = nn.Dense(feat1)
def __call__(self, inputs):
x = inputs
for i, lyr in enumerate(self.layers):
x = lyr(x)
if i != len(self.layers) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03
-1.7147182e-02]
[ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02
-4.5417298e-02]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00]
[ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04
-1.0110770e-03]]
如我们所见,nn.Module
子类由以下组成
一个数据字段的集合(
nn.Module
是 Python 的数据类) - 这里我们只有类型为Sequence[int]
的features
字段。一个
setup()
方法,它在__postinit__
的末尾被调用,您可以在其中注册模型所需的子模块、变量和参数。一个
__call__
函数,它从给定输入返回模型的输出。模型结构定义了参数的 pytree,该参数遵循与模型相同的树结构:params 树为每个层包含一个
layers_n
子字典,每个子字典都包含关联的 Dense 层的参数。布局非常明确。
注意:列表的管理方式与您的预期基本一致(WIP),您应该注意一些特殊情况,如 here
由于模块结构及其参数彼此不绑定,因此您不能直接对给定输入调用 model(x)
,因为它将返回错误。 __call__
函数被封装在 apply
函数中,它是用于对输入进行调用的函数
try:
y = model(x) # Returns an error
except AttributeError as e:
print(e)
"ExplicitMLP" object has no attribute "layers"
由于这里我们有一个非常简单的模型,我们可以使用另一种(但等效的)方法在 __call__
中使用 @nn.compact
注解内联声明子模块,如下所示
class SimpleMLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
x = nn.Dense(feat, name=f'layers_{i}')(x)
if i != len(self.features) - 1:
x = nn.relu(x)
# providing a name is optional though!
# the default autonames would be "Dense_0", "Dense_1", ...
return x
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03
-1.7147182e-02]
[ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02
-4.5417298e-02]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00]
[ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04
-1.0110770e-03]]
但是,您应该注意这两种声明模式之间存在一些差异
在
setup
中,您可以命名一些子层并将其保留以供将来使用(例如,自动编码器中的编码器/解码器方法)。如果要使用多个方法,则**必须**使用
setup
声明模块,因为@nn.compact
注解仅允许注释一个方法。最后一次初始化将以不同的方式处理。有关更多详细信息,请参阅以下说明(TODO:添加说明链接)。
模块参数#
在之前的 MLP 示例中,我们仅依赖于预定义的层和运算符(Dense
、relu
)。假设您没有 Flax 提供的 Dense 层,并且想要自己编写它。以下是使用 @nn.compact
方法声明新模块的方式
class SimpleDense(nn.Module):
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros_init()
@nn.compact
def __call__(self, inputs):
kernel = self.param('kernel',
self.kernel_init, # Initialization function
(inputs.shape[-1], self.features)) # shape info.
y = jnp.dot(inputs, kernel)
bias = self.param('bias', self.bias_init, (self.features,))
y = y + bias
return y
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameters:\n', params)
print('output:\n', y)
initialized parameters:
FrozenDict({
params: {
kernel: DeviceArray([[ 0.6503669 , 0.86789787, 0.4604268 ],
[ 0.05673932, 0.9909285 , -0.63536596],
[ 0.76134115, -0.3250529 , -0.65221626],
[-0.82430327, 0.4150194 , 0.19405058]], dtype=float32),
bias: DeviceArray([0., 0., 0.], dtype=float32),
},
})
output:
[[ 0.5035518 1.8548558 -0.4270195 ]
[ 0.0279097 0.5589246 -0.43061772]
[ 0.3547128 1.5740999 -0.32865518]
[ 0.5264864 1.2928858 0.10089308]]
在这里,我们看到如何使用 self.param
方法声明和分配模型的参数。它接受 (name, init_fn, *init_args, **init_kwargs)
作为输入
name
只是最终出现在参数结构中的参数的名称。init_fn
是一个具有输入(PRNGKey, *init_args, **init_kwargs)
的函数,返回一个数组,init_args
和init_kwargs
是调用初始化函数所需的论据。init_args
和init_kwargs
是要提供给初始化函数的论据。
此类参数也可以在 setup
方法中声明;它将无法使用形状推断,因为 Flax 在第一个调用站点使用延迟初始化。
变量和变量集合#
到目前为止,我们已经看到,使用模型意味着使用
nn.Module
的子类;模型的参数 pytree(通常来自
model.init()
);
但是,这还不够涵盖机器学习(尤其是神经网络)所需的一切。在某些情况下,您可能希望神经网络在运行时跟踪一些内部状态(例如,批处理归一化层)。可以使用 variable
方法声明超出模型参数的变量。
为了演示目的,我们将实现一个简化但类似于批处理归一化的机制:我们将存储运行平均值并在训练时将其减去输入。对于正确的批处理归一化,您应该使用(并查看)这里的实现。
class BiasAdderWithRunningMean(nn.Module):
decay: float = 0.99
@nn.compact
def __call__(self, x):
# easy pattern to detect if we're initializing via empty variable tree
is_initialized = self.has_variable('batch_stats', 'mean')
ra_mean = self.variable('batch_stats', 'mean',
lambda s: jnp.zeros(s),
x.shape[1:])
bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
if is_initialized:
ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
return x - ra_mean.value + bias
key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
initialized variables:
FrozenDict({
batch_stats: {
mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
},
params: {
bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
},
})
updated state:
FrozenDict({
batch_stats: {
mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
},
})
这里,updated_state
仅返回在模型上应用数据时被模型修改的状态变量。要更新变量并获取模型的新参数,我们可以使用以下模式
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
old_state, params = flax.core.pop(variables, 'params')
variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
updated state:
FrozenDict({
batch_stats: {
mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
},
})
updated state:
FrozenDict({
batch_stats: {
mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),
},
})
updated state:
FrozenDict({
batch_stats: {
mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),
},
})
从这个简化的示例中,您应该能够推导出完整的 BatchNorm 实现,或任何涉及状态的层。最后,让我们添加一个优化器,看看如何使用优化器更新的参数和状态变量。
这个示例并没有执行任何操作,仅用于演示目的。
from functools import partial
@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):
def loss(params):
y, updated_state = apply_fn({'params': params, **state},
x, mutable=list(state.keys()))
l = ((x - y) ** 2).sum()
return l, updated_state
(l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return opt_state, params, state
x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
for _ in range(3):
opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
print('Updated state: ', state)
Updated state: FrozenDict({
batch_stats: {
mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
},
})
Updated state: FrozenDict({
batch_stats: {
mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),
},
})
Updated state: FrozenDict({
batch_stats: {
mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),
},
})
请注意,上述函数的签名非常冗长,实际上无法与 jax.jit()
一起使用,因为函数参数不是“有效的 JAX 类型”。
Flax 提供了一个方便的包装器 - TrainState
- 它简化了上述代码。查看 flax.training.train_state.TrainState
以了解更多信息。