目录
动机#
我们当前的 API(在本文件中称为 先前 API)使用一种模式,其中 Optimizer
数据类是从 target
变量的 pytree 和定义如何更新优化器状态、超参数和目标变量的 OptimizerDef
创建的。这种模式对于实现简单的优化器来说比较复杂,而在典型的 Linen 训练步骤中却很冗长(尤其是在使用可变状态集合时)。
此包 flax.optim
包含一些优化器,但这个列表远非详尽,理想情况下,我们会使用来自专用 PyPi 包的 JAX 优化器。
DeepMind 已经有一个专用库 - Optax - 它实现了广泛的有趣的优化器,并提供了一个框架,可以将可重用梯度转换组合成新的优化器。
使用 Optax#
梯度转换#
虽然 Optax 确实提供了预定义的优化器(如 optax.adam
或带有动量的 optax.sgd
),但它实际上是一个梯度转换库,创建优化器的惯用方式是提供这些梯度转换的组合。要模仿使用 先前 API 时示例中的动量优化器,我们将编写
import optax
tx = optax.chain(
optax.trace(decay=0.9, nesterov=False),
optax.scale_by_schedule(lambda step: -get_learning_rate(step)),
)
备注
上面的梯度转换将等效于在 优化器和 OptimizerDef 下定义的示例,我们将在其中定义没有 Nesterov 动量的动量优化器(请注意,
beta
参数对应于optax.trace()
转换的decay
参数,学习率在第二个链接的转换中应用)。请注意,像
decay
或nesterov
这样的超参数仅存在于返回GradientTransformation
的高阶函数的内部作用域中。这种梯度转换当前被定义为init()
和update()
函数的NamedTuple
。原则上,这种模式可以扩展到也存储超参数,这可能是 Optax 存储库中需要讨论的一个点。我们可以在定义 Optax 梯度更新转换时使用一个
get_learning_rate()
,它根据步骤号返回学习率。上面的代码说明了这如何成为我们还在 先前训练步骤 中使用的函数的直接替代,其中此更新函数已经存在(请注意,我们需要反转符号,因为我们将梯度更新添加到参数中)。此外,您可以使用inject_hyperparams()
来调度 Optax 的任意超参数。
Optax 训练步骤#
@functools.partial(jax.jit, static_argnums=(4, 5))
def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn):
def loss_fn(params):
logits, new_model_state = apply_fn(
{**variables, 'params': params}, inputs, mutable=['batch_stats'])
loss = xent_loss(logits, labels)
return loss, new_model_state
variables, params = variables.pop('params')
(loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
params)
updates, new_opt_state = tx_update_fn(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
new_variables = {**variables, **new_model_state, 'params': new_params}
return new_opt_state, new_variables, loss
opt_state = tx.init(variables['params'])
for batch in ds.as_numpy_iterator():
opt_state, variables, loss = train_step(
opt_state, variables, batch['image'], batch['label'], model.apply,
tx.update)
print(loss)
备注
由于
tx.update()
仅转换梯度,因此我们仍然需要调用optax.apply_updates()
来将这些转换后的梯度应用于参数。与 先前 API 相比,我们现在可以将整个
variables
(包括params
)作为train_step()
的输入和输出。在训练步骤中仍然需要将
params
与variables
分开,因为我们只想计算相对于params
而不是整个variables
的梯度。我们仍然可以记录内部优化器状态,例如学习率,只要 Optax 转换在各自状态中公开该信息即可。例如,
optax.scale_by_schedule()
目前仅公开了opt_state.count
,但可以轻松地扩展以也公开step_size
。对于随时间变化的内部优化器状态也是如此。
多优化器#
先前 API 定义了 flax.optim.MultiOptimizer
,用于使用不同的优化器处理参数树的不同部分
biases_traversal = flax.optim.ModelParamTraversal(
lambda path, _: path.endswith('/bias'))
not_biases_traversal = flax.optim.ModelParamTraversal(
lambda path, _: not path.endswith('/bias'))
optimizer_def = flax.optim.MultiOptimizer(
(biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)),
(not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)),
)
请注意,我们首先定义一个遍历,该遍历根据参数路径(即模块范围和变量名的连接)选择参数,然后创建一个 MultiOptimizer
,该 MultiOptimizer
为每个单独的遍历绑定不同的优化器。
Optax 最近实现了 optax.masked()
,它可以用于指定仅应用于梯度子集的梯度转换
def flattened_traversal(fn):
def mask(data):
flat = traverse_util.flatten_dict(data)
return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
return mask
tx = optax.chain(
optax.masked(optax.sgd(learning_rate=0.1),
mask=flattened_traversal(lambda path, _: path[-1] == 'bias')),
optax.masked(optax.sgd(learning_rate=0.05),
mask=flattened_traversal(lambda path, _: path[-1] != 'bias')),
)
训练状态#
在 Flax 中,通常会传递一个 TrainState
对象,然后可以使用该对象进行检查点。这通过减少参数数量和消除 static_argnums
来简化上面的 Optax 训练步骤。
我们可以定义一个 TrainState
数据类,它包装通过应用梯度来更新优化器状态和参数的常见模式。
# Small helper class in flax.training
class TrainState(flax.struct.PyTreeNode):
step: int
apply_fn: Callable = flax.struct.field(pytree_node=False)
params: flax.core.FrozenDict[str, Any]
tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
opt_state: optax.OptState
def apply_gradients(self, *, grads, **kwargs):
updates, new_opt_state = self.tx.update(
grads, self.opt_state, self.params)
new_params = optax.apply_updates(self.params, updates)
return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
)
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
opt_state = tx.init(params)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
tx=tx,
opt_state=opt_state,
**kwargs,
)
用户可以从该数据类派生,并添加更多字段,例如可变模型状态
from flax.training import train_state
class TrainState(train_state.TrainState):
batch_stats: flax.core.FrozenDict[str, Any]
有了它,Optax 训练步骤 变成
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params):
outputs, new_model_state = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
inputs,
mutable=['batch_stats'])
loss = xent_loss(outputs, labels)
return loss, new_model_state
(loss, new_model_state), grads = jax.value_and_grad(
loss_fn, has_aux=True)(state.params)
new_state = state.apply_gradients(
grads=grads,
batch_stats=new_model_state['batch_stats'],
)
return new_state, loss
state = TrainState.create(
apply_fn=model.apply,
params=variables['params'],
tx=tx,
batch_stats=variables['batch_stats'],
)
for batch in ds.as_numpy_iterator():
state, loss = train_step(state, batch['image'], batch['label'])
没有可变状态的训练步骤简化为
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params):
outputs = state.apply_fn({'params': params}, inputs)
loss = xent_loss(outputs, labels)
return loss
loss, grads = jax.value_and_grad(loss_fn)(state.params)
new_state = state.update(grads=grads)
return new_state, loss
state = flax.training.TrainState.create(
apply_fn=model.apply,
params=variables['params'],
tx=tx,
)
for batch in ds.as_numpy_iterator():
state, loss = train_step(state, batch['image'], batch['label'])
备注
在 Flax 训练循环中,有一个
TrainState
数据类在每一步后用新状态更新是一个常见模式。在
flax.training.train_state
中提出的简单解决方案可以扩展为包含更多数据,但高级用例(例如多个不同的模型和/或优化器)不受支持。用户应该改而对数据类进行分支,并根据自己的需要重新实现它。与 先前 API 中的
Optimizer
抽象不同,TrainState
现在直接包含.params
,无需再通过.optimizer
先前 API#
优化器和 OptimizerDef#
优化器本身将通过创建一个从 OpimizerDef
派生的新类来实现
# flax/optim/momentum.py
@flax.struct.dataclass
class _MomentumHyperParams:
learning_rate: jnp.ndarray
beta: jnp.ndarray
@flax.struct.dataclass
class _MomentumParamState:
momentum: np.ndarray
class Momentum(flax.optim.OptimizerDef):
def __init__(self, learning_rate=None, beta=0.9):
super().__init__(
_MomentumHyperParams(learning_rate, beta)
)
def init_param_state(self, param):
return _MomentumParamState(jnp.zeros_like(param))
def apply_param_gradient(self, step, hyper_params, param, state, grad):
del step
assert hyper_params.learning_rate is not None
new_momentum = state.momentum * hyper_params.beta + grad
new_params = param - hyper_params.learning_rate * new_momentum
return new_params, _MomentumParamState(new_momentum)
备注
注意
OptimizerDef
和Optimizer
之间的关系:当用户代码调用函数Optimizer.apply_gradient()
时,它会调用OptimizerDef.apply_gradient()
(以及其他操作),进而调用OptimizerDef.apply_param_gradient()
(由OptimizerDef
的子类实现)。函数
init_param_state()
和apply_param_gradient()
会对 params/grads pytree 中的每个叶节点调用。这使得可以直接编写计算,无需使用jax.tree_util.tree_map()
。在 Linen 之前,接口的定义没有考虑
params
与variables
中其他集合的区别。最初的 API 非常优雅,因为只需要传递优化器,其中包含参数、优化器状态、优化器超参数以及对OptimizerDef
的引用即可执行参数/状态更新。
之前的训练步骤#
优化器首先会根据其定义和目标参数的 pytree 进行构建
optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(variables['params'])
然后,目标变量会在训练步骤中被优化(假设只有一个非参数集合“batch_stats”)
def make_train_step(apply_fn):
@jax.jit
def train_step(optimizer, batch_stats, inputs, labels):
def loss_fn(params):
variables = {'params': params, 'batch_stats': batch_stats}
logits, new_model_state = apply_fn(
variables, inputs, mutable=['batch_stats'])
loss = xent_loss(logits, labels)
return loss, new_model_state['batch_stats']
(loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(
optimizer.target)
lr = get_learning_rate(step)
new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
return new_optimizer, new_batch_stats, loss
return train_step
batch_stats = variables['batch_stats']
train_step = make_train_step(model.apply)
for step, batch in enumerate(ds)
optimizer, batch_stats, loss = train_step(
optimizer, batch_stats, batch['image'], batch['label'])
备注
注意,
optimizer.apply_gradient()
可以接收额外的参数来更新超参数,例如在本例中来自独立函数get_learning_rate()
的学习率。
更新计划#
完成关于此 FLIP 的讨论
为 Optax 添加 等价性测试,确保现有的
flax.optim
优化器与相应的optax
优化器返回相同的值。更新示例以使用 Optax,并验证它们在相同的计算成本下达到了相同的最终性能。
将 Optax 中缺少的优化器移植(例如 Adafactor),并验证上述要点。
更新所有文档(包括 README、Flax 指南、HOWTO 等),专门介绍 Optax 优化器。
创建一个从
flax.optim
迁移到使用 Optax 的过渡指南。该指南还应指向 Optax 的 等价性测试 和更新示例的 pull 请求。将
flax.optim
中的优化器标记为已弃用。
注意,所有当前的 Flax 示例都使用 Optax 中已有的优化器
示例 |
Flax |
Optax |
评论 |
---|---|---|---|
imagenet |
optim.Momentum |
optax.sgd |
DynamicScale 可以保持不变。 |
mnist |
optim.Momentum |
optax.sgd |
|
nlp_seq |
optim.Adam |
optax.adamw |
|
pixelcnn |
optim.Adam |
optax.adam |
|
ppo |
optim.Adam |
optax.adam |
|
seq2seq |
optim.Adam |
optax.adam |
|
vae |
optim.Adam |
optax.adam |
|
wmt |
optim.Adam |
optax.adamw |
(Flax 的 Adam 实现有一个可选参数用于权重衰减,但在 Optax 中,带权重衰减和不带权重衰减的 Adam 是两个不同的别名。)
附录#
设置代码#
以下设置代码可用于运行此 FLIP 中的代码段
import functools
from typing import Callable, Sequence
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds
def pp(features):
return {
'image': tf.cast(features['image'], tf.float32) / 255 - 0.5,
'label': features['label'],
}
class Model(nn.Module):
@nn.compact
def __call__(self, inputs):
x = inputs.reshape([inputs.shape[0], -1])
x = nn.normalization.BatchNorm(True)(x)
x = nn.Dense(10)(x)
x = nn.log_softmax(x)
return x
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
x = (labels[..., None] == jnp.arange(num_classes)[None])
x = jax.lax.select(
x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
return x.astype(jnp.float32)
def xent_loss(logits, labels):
return -jnp.sum(
onehot(labels, num_classes=10) * logits) / labels.size
def get_learning_rate(step):
return 0.1
model = Model()
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
jax.tree_util.tree_map(jnp.shape, variables)