将我的代码库升级到Linen#
从Flax v0.4.0开始,flax.nn
不再存在,并被位于flax.linen
的新Linen API替换。如果您的代码库仍在使用旧的API,您可以使用此升级指南将其升级到Linen。
定义简单的Flax模块#
from flax import nn
class Dense(base.Module):
def apply(self,
inputs,
features,
use_bias=True,
kernel_init=default_kernel_init,
bias_init=initializers.zeros_init()):
kernel = self.param('kernel',
(inputs.shape[-1], features), kernel_init)
y = jnp.dot(inputs, kernel)
if use_bias:
bias = self.param(
'bias', (features,), bias_init)
y = y + bias
return y
from flax import linen as nn # [1]
class Dense(nn.Module):
features: int # [2]
use_bias: bool = True
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
@nn.compact
def __call__(self, inputs): # [3]
kernel = self.param('kernel',
self.kernel_init, (inputs.shape[-1], self.features)) # [4]
y = jnp.dot(inputs, kernel)
if self.use_bias:
bias = self.param(
'bias', self.bias_init, (self.features,)) # [5]
y = y + bias
return y
将
from flax import nn
替换为from flax import linen as nn
。将参数移动到
apply
中成为dataclass属性。添加类型注释(或使用类型Any
跳过)。将方法
apply
重命名为__call__
,并(可选)用@compact
包装。用@compact
包装的方法可以直接在方法内定义子模块(就像在旧的Flax中一样)。您只能用@compact
包装一个方法。或者,您可以定义一个setup
方法。有关更多详细信息,请参阅我们的其他HOWTO 我应该使用setup还是nn.compact?。在方法内部通过
self.<attr>
访问dataclass属性值,例如self.features
。将形状移动到
self.param
的参数末尾(初始化器函数可以接受任意参数列表)。
在其他模块中使用Flax模块#
class Encoder(nn.Module):
def apply(self, x):
x = nn.Dense(x, 500)
x = nn.relu(x)
z = nn.Dense(x, 500, name="latents")
return z
class Encoder(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(500)(x) # [1]
x = nn.relu(x)
z = nn.Dense(500, name='latents')(x) # [2]
return z
模块构造函数不再返回输出。相反,它们的工作方式与普通构造函数一样,并返回模块实例。这些实例可以像在普通Python中一样共享(而不是在旧的Flax中使用
.shared()
)。由于大多数模块都实现了__call__
,因此您可以保留旧Flax的简洁性。名称可以可选地传递给所有模块构造函数。
共享子模块和定义多个方法#
class AutoEncoder(nn.Module):
def _create_submodules(self):
return Decoder.shared(name="encoder")
def apply(self, x, z_rng, latents=20):
decoder = self._create_decoder()
z = Encoder(x, latents, name="encoder")
return decoder(z)
@nn.module_method
def generate(self, z, **unused_kwargs):
decoder = self._create_decoder()
return nn.sigmoid(decoder(z))
class AutoEncoder(nn.Module):
latents: int = 20
def setup(self): # [1]
self.encoder = Encoder(self.latents) # [2]
self.decoder = Decoder()
def __call__(self, x): # [3]
z = self.encoder(x)
return self.decoder(z)
def generate(self, z): # [4]
return nn.sigmoid(self.decoder(z))
Module.partial
在其他模块中#
# no import
class ResNet(nn.Module):
"""ResNetV1."""
def apply(self, x,
stage_sizes,
num_filters=64,
train=True):
conv = nn.Conv.partial(bias=False)
norm = nn.BatchNorm.partial(
use_running_average=not train,
momentum=0.9, epsilon=1e-5)
x = conv(x, num_filters, (7, 7), (2, 2),
padding=[(3, 3), (3, 3)],
name='conv_init')
x = norm(x, name='bn_init')
# [...]
return x
from functools import partial
class ResNet(nn.Module):
"""ResNetV1."""
stage_sizes: Sequence[int]
num_filters: int = 64
train: bool = True
@nn.compact
def __call__(self, x):
conv = partial(nn.Conv, use_bias=False)
norm = partial(nn.BatchNorm,
use_running_average=not self.train,
momentum=0.9, epsilon=1e-5)
x = conv(self.num_filters, (7, 7), (2, 2),
padding=[(3, 3), (3, 3)],
name='conv_init')(x)
x = norm(name='bn_init')(x)
# [...]
return x
使用普通的functools.partial
而不是Module.partial
。其余部分保持不变。
顶级训练代码模式#
def create_model(key):
_, initial_params = CNN.init_by_shape(
key, [((1, 28, 28, 1), jnp.float32)])
model = nn.Model(CNN, initial_params)
return model
def create_optimizer(model, learning_rate):
optimizer_def = optim.Momentum(learning_rate=learning_rate)
optimizer = optimizer_def.create(model)
return optimizer
def cross_entropy_loss(*, logits, labels):
one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
def loss_fn(model):
logits = model(batch['image'])
one_hot = jax.nn.one_hot(batch['label'], num_classes=10)
loss = -jnp.mean(jnp.sum(one_hot_labels * batch['label'],
axis=-1))
return loss, logits
def create_train_state(rng, config): # [1]
variables = CNN().init(rng, jnp.ones([1, 28, 28, 1])) # [2]
params = variables['params'] # [3]
tx = optax.sgd(config.learning_rate, config.momentum) # [4]
return train_state.TrainState.create(
apply_fn=CNN.apply, params=params, tx=tx)
def loss_fn(params):
logits = CNN().apply({'params': params}, batch['image']) # [5]
one_hot = jax.nn.one_hot(batch['label'], 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits,
labels=one_hot))
return loss, logits
我们不再使用
Model
抽象——而是直接传递参数,通常封装在TrainState对象中,该对象可以直接传递给JAX转换。要计算初始参数,构造一个模块实例并调用
init
或init_with_output
。我们没有移植init_by_shape
,因为此函数做了一些我们不喜欢的魔术(它根据形状评估函数,但无论如何都返回真实值)。因此,您现在应该将具体值传递给初始化器函数,并且可以通过用jax.jit
包装它来优化初始化,这强烈建议避免运行完整的正向传递。Linen将参数概括为变量。参数是变量的其中一个“集合”。变量是嵌套字典,其中顶级键反映了不同的变量集合,“param”是其中之一。有关更多详细信息,请参阅变量文档。
我们建议使用Optax优化器。有关更多详细信息,请参阅我们名为将我的代码库升级到Optax的单独HOWTO。
要使用您的模型进行预测,请在顶级创建一个实例(这是免费的——只是构造函数属性的包装器)并调用
apply
方法(它将在内部调用__call__
)。
不可训练变量(“状态”):在模块中使用#
class BatchNorm(nn.Module):
def apply(self, x):
# [...]
ra_mean = self.state(
'mean', (x.shape[-1], ), initializers.zeros_init())
ra_var = self.state(
'var', (x.shape[-1], ), initializers.ones_init())
# [...]
class BatchNorm(nn.Module):
def __call__(self, x):
# [...]
ra_mean = self.variable(
'batch_stats', 'mean', initializers.zeros_init(), (x.shape[-1], ))
ra_var = self.variable(
'batch_stats', 'var', initializers.ones_init(), (x.shape[-1], ))
# [...]
第一个参数是变量集合的名称(“param”是始终可用的唯一变量集合)。一些集合可能会被视为可变的,而其他集合在顶级训练代码中被视为不可变的(有关详细信息,请参阅下一节)。Flax还允许您在模块内使用JAX转换时以不同的方式处理每个变量集合。
不可训练变量(“状态”):顶级训练代码模式#
# initial params and state
def initial_model(key, init_batch):
with nn.stateful() as initial_state:
_, initial_params = ResNet.init(key, init_batch)
model = nn.Model(ResNet, initial_params)
return model, init_state
# updates batch statistics during training
def loss_fn(model, model_state):
with nn.stateful(model_state) as new_model_state:
logits = model(batch['image'])
# [...]
# reads immutable batch statistics during evaluation
def eval_step(model, model_state, batch):
with nn.stateful(model_state, mutable=False):
logits = model(batch['image'], train=False)
return compute_metrics(logits, batch['label'])
# initial variables ({"param": ..., "batch_stats": ...})
def initial_variables(key, init_batch):
return ResNet().init(key, init_batch) # [1]
# updates batch statistics during training
def loss_fn(params, batch_stats):
variables = {'params': params, 'batch_stats': batch_stats} # [2]
logits, new_variables = ResNet(train=true).apply(
variables, batch['image'], mutable=['batch_stats']) # [3]
new_batch_stats = new_variables['batch_stats']
# [...]
# reads immutable batch statistics during evaluation
def eval_step(params, batch_stats, batch):
variables = {'params': params, 'batch_stats': batch_stats}
logits = ResNet(train=False).apply(
variables, batch['image'], mutable=False) # [4]
return compute_metrics(logits, batch['label'])
加载预Linen检查点#
虽然大多数Linen模块应该能够在没有任何修改的情况下使用预Linen权重,但有一个需要注意的地方:在预Linen API中,子模块按增量编号,独立于子模块类。使用Linen后,此行为已更改为保持每个模块类的单独子模块计数。
在预Linen中,参数具有以下结构
{'Conv_0': { ... }, 'Dense_1': { ... } }
在 Linen 中,改为如下:
{'Conv_0': { ... }, 'Dense_0': { ... } }
待办事项:在此处添加一个关于如何加载新的 TrainState
对象的示例。
随机性#
def dropout(inputs, rate, deterministic=False):
keep_prob = 1. - rate
if deterministic:
return inputs
else:
mask = random.bernoulli(
make_rng(), p=keep_prob, shape=inputs.shape)
return lax.select(
mask, inputs / keep_prob, jnp.zeros_like(inputs))
def loss_fn(model, dropout_rng):
with nn.stochastic(dropout_rng):
logits = model(inputs)
class Dropout(nn.Module):
rate: float
@nn.compact
def __call__(self, inputs, deterministic=False):
keep_prob = 1. - self.rate
if deterministic:
return inputs
else:
mask = random.bernoulli(
self.make_rng('dropout'), p=keep_prob, shape=inputs.shape) # [1]
return lax.select(
mask, inputs / keep_prob, jnp.zeros_like(inputs))
def loss_fn(params, dropout_rng):
logits = Transformer().apply(
{'params': params}, inputs, rngs={'dropout': dropout_rng}) # [2]
Linen 中的 RNG 具有“种类”——在本例中为
'dropout'
。在 JAX 变换中,不同的种类可以被不同地处理(例如,您是否希望在序列模型中的每个时间步长都使用相同的 dropout 掩码,或者使用不同的掩码?)您无需使用
nn.stochastic
上下文管理器,而是将 RNG 显式地传递给module.apply
。在评估期间,您不会传递任何 RNG——然后,如果您在非确定性模式下意外使用了 dropout,self.make_rng('dropout')
将引发错误。