• 开始日期:2021-02-08

  • FLIP PR:#1011

  • FLIP 问题:#1009

目录

摘要#

此 FLIP 提出用 Optax(DeepMind 的优化器库)替换我们当前的 flax.optim API(在本文件中称为 先前 API)。

动机#

我们当前的 API(在本文件中称为 先前 API)使用一种模式,其中 Optimizer 数据类是从 target 变量的 pytree 和定义如何更新优化器状态、超参数和目标变量的 OptimizerDef 创建的。这种模式对于实现简单的优化器来说比较复杂,而在典型的 Linen 训练步骤中却很冗长(尤其是在使用可变状态集合时)。

此包 flax.optim 包含一些优化器,但这个列表远非详尽,理想情况下,我们会使用来自专用 PyPi 包的 JAX 优化器。

DeepMind 已经有一个专用库 - Optax - 它实现了广泛的有趣的优化器,并提供了一个框架,可以将可重用梯度转换组合成新的优化器。

使用 Optax#

梯度转换#

虽然 Optax 确实提供了预定义的优化器(如 optax.adam 或带有动量的 optax.sgd),但它实际上是一个梯度转换库,创建优化器的惯用方式是提供这些梯度转换的组合。要模仿使用 先前 API 时示例中的动量优化器,我们将编写

import optax

tx = optax.chain(
    optax.trace(decay=0.9, nesterov=False),
    optax.scale_by_schedule(lambda step: -get_learning_rate(step)),
)

备注

  • 上面的梯度转换将等效于在 优化器和 OptimizerDef 下定义的示例,我们将在其中定义没有 Nesterov 动量的动量优化器(请注意,beta 参数对应于 optax.trace() 转换的 decay 参数,学习率在第二个链接的转换中应用)。

  • 请注意,像 decaynesterov 这样的超参数仅存在于返回 GradientTransformation 的高阶函数的内部作用域中。这种梯度转换当前被定义为 init()update() 函数的 NamedTuple。原则上,这种模式可以扩展到也存储超参数,这可能是 Optax 存储库中需要讨论的一个点。

  • 我们可以在定义 Optax 梯度更新转换时使用一个 get_learning_rate(),它根据步骤号返回学习率。上面的代码说明了这如何成为我们还在 先前训练步骤 中使用的函数的直接替代,其中此更新函数已经存在(请注意,我们需要反转符号,因为我们将梯度更新添加到参数中)。此外,您可以使用 inject_hyperparams() 来调度 Optax 的任意超参数。

Optax 训练步骤#

@functools.partial(jax.jit, static_argnums=(4, 5))
def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn):

  def loss_fn(params):
    logits, new_model_state = apply_fn(
        {**variables, 'params': params}, inputs, mutable=['batch_stats'])
    loss = xent_loss(logits, labels)
    return loss, new_model_state

  variables, params = variables.pop('params')
  (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      params)
  updates, new_opt_state = tx_update_fn(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  new_variables = {**variables, **new_model_state, 'params': new_params}
  return new_opt_state, new_variables, loss


opt_state = tx.init(variables['params'])
for batch in ds.as_numpy_iterator():
  opt_state, variables, loss = train_step(
      opt_state, variables, batch['image'], batch['label'], model.apply,
      tx.update)
  print(loss)

备注

  • 由于 tx.update() 仅转换梯度,因此我们仍然需要调用 optax.apply_updates() 来将这些转换后的梯度应用于参数。

  • 先前 API 相比,我们现在可以将整个 variables(包括 params)作为 train_step() 的输入和输出。

  • 在训练步骤中仍然需要将 paramsvariables 分开,因为我们只想计算相对于 params 而不是整个 variables 的梯度。

  • 我们仍然可以记录内部优化器状态,例如学习率,只要 Optax 转换在各自状态中公开该信息即可。例如,optax.scale_by_schedule() 目前仅公开了 opt_state.count,但可以轻松地扩展以也公开 step_size。对于随时间变化的内部优化器状态也是如此。

多优化器#

先前 API 定义了 flax.optim.MultiOptimizer,用于使用不同的优化器处理参数树的不同部分

biases_traversal = flax.optim.ModelParamTraversal(
    lambda path, _: path.endswith('/bias'))
not_biases_traversal = flax.optim.ModelParamTraversal(
    lambda path, _: not path.endswith('/bias'))

optimizer_def = flax.optim.MultiOptimizer(
    (biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)),
    (not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)),
)

请注意,我们首先定义一个遍历,该遍历根据参数路径(即模块范围和变量名的连接)选择参数,然后创建一个 MultiOptimizer,该 MultiOptimizer 为每个单独的遍历绑定不同的优化器。

Optax 最近实现了 optax.masked(),它可以用于指定仅应用于梯度子集的梯度转换

def flattened_traversal(fn):
  def mask(data):
    flat = traverse_util.flatten_dict(data)
    return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
  return mask

tx = optax.chain(
    optax.masked(optax.sgd(learning_rate=0.1),
                 mask=flattened_traversal(lambda path, _: path[-1] == 'bias')),
    optax.masked(optax.sgd(learning_rate=0.05),
                 mask=flattened_traversal(lambda path, _: path[-1] != 'bias')),
)

训练状态#

在 Flax 中,通常会传递一个 TrainState 对象,然后可以使用该对象进行检查点。这通过减少参数数量和消除 static_argnums 来简化上面的 Optax 训练步骤

我们可以定义一个 TrainState 数据类,它包装通过应用梯度来更新优化器状态和参数的常见模式。

# Small helper class in flax.training
class TrainState(flax.struct.PyTreeNode):
  step: int
  apply_fn: Callable = flax.struct.field(pytree_node=False)
  params: flax.core.FrozenDict[str, Any]
  tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
  opt_state: optax.OptState

  def apply_gradients(self, *, grads, **kwargs):
    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, self.params)
    new_params = optax.apply_updates(self.params, updates)
    return self.replace(
        step=self.step + 1,
        params=new_params,
        opt_state=new_opt_state,
        **kwargs,
    )

  @classmethod
  def create(cls, *, apply_fn, params, tx, **kwargs):
    opt_state = tx.init(params)
    return cls(
        step=0,
        apply_fn=apply_fn,
        params=params,
        tx=tx,
        opt_state=opt_state,
        **kwargs,
    )

用户可以从该数据类派生,并添加更多字段,例如可变模型状态

from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: flax.core.FrozenDict[str, Any]

有了它,Optax 训练步骤 变成

@jax.jit
def train_step(state, inputs, labels):

  def loss_fn(params):
    outputs, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        inputs,
        mutable=['batch_stats'])
    loss = xent_loss(outputs, labels)
    return loss, new_model_state

  (loss, new_model_state), grads = jax.value_and_grad(
      loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(
      grads=grads,
      batch_stats=new_model_state['batch_stats'],
  )

  return new_state, loss


state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
    batch_stats=variables['batch_stats'],
)
for batch in ds.as_numpy_iterator():
  state, loss = train_step(state, batch['image'], batch['label'])

没有可变状态的训练步骤简化为

@jax.jit
def train_step(state, inputs, labels):

  def loss_fn(params):
    outputs = state.apply_fn({'params': params}, inputs)
    loss = xent_loss(outputs, labels)
    return loss

  loss, grads = jax.value_and_grad(loss_fn)(state.params)
  new_state = state.update(grads=grads)

  return new_state, loss


state = flax.training.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
)
for batch in ds.as_numpy_iterator():
  state, loss = train_step(state, batch['image'], batch['label'])

备注

  • 在 Flax 训练循环中,有一个 TrainState 数据类在每一步后用新状态更新是一个常见模式。

  • flax.training.train_state 中提出的简单解决方案可以扩展为包含更多数据,但高级用例(例如多个不同的模型和/或优化器)不受支持。用户应该改而对数据类进行分支,并根据自己的需要重新实现它。

  • 先前 API 中的 Optimizer 抽象不同,TrainState 现在直接包含 .params,无需再通过 .optimizer

先前 API#

优化器和 OptimizerDef#

优化器本身将通过创建一个从 OpimizerDef 派生的新类来实现

# flax/optim/momentum.py

@flax.struct.dataclass
class _MomentumHyperParams:
  learning_rate: jnp.ndarray
  beta: jnp.ndarray


@flax.struct.dataclass
class _MomentumParamState:
  momentum: np.ndarray


class Momentum(flax.optim.OptimizerDef):

  def __init__(self, learning_rate=None, beta=0.9):
    super().__init__(
      _MomentumHyperParams(learning_rate, beta)
    )

  def init_param_state(self, param):
    return _MomentumParamState(jnp.zeros_like(param))

  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    del step
    assert hyper_params.learning_rate is not None
    new_momentum = state.momentum * hyper_params.beta + grad
    new_params = param - hyper_params.learning_rate * new_momentum
    return new_params, _MomentumParamState(new_momentum)

备注

  • 注意 OptimizerDefOptimizer 之间的关系:当用户代码调用函数 Optimizer.apply_gradient() 时,它会调用 OptimizerDef.apply_gradient()(以及其他操作),进而调用 OptimizerDef.apply_param_gradient()(由 OptimizerDef 的子类实现)。

  • 函数 init_param_state()apply_param_gradient() 会对 params/grads pytree 中的每个叶节点调用。这使得可以直接编写计算,无需使用 jax.tree_util.tree_map()

  • 在 Linen 之前,接口的定义没有考虑 paramsvariables 中其他集合的区别。最初的 API 非常优雅,因为只需要传递优化器,其中包含参数、优化器状态、优化器超参数以及对 OptimizerDef 的引用即可执行参数/状态更新。

之前的训练步骤#

优化器首先会根据其定义和目标参数的 pytree 进行构建

optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(variables['params'])

然后,目标变量会在训练步骤中被优化(假设只有一个非参数集合“batch_stats”)

def make_train_step(apply_fn):
  @jax.jit
  def train_step(optimizer, batch_stats, inputs, labels):

    def loss_fn(params):
      variables = {'params': params, 'batch_stats': batch_stats}
      logits, new_model_state = apply_fn(
          variables, inputs, mutable=['batch_stats'])
      loss = xent_loss(logits, labels)
      return loss, new_model_state['batch_stats']

    (loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(
        optimizer.target)
    lr = get_learning_rate(step)
    new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
    return new_optimizer, new_batch_stats, loss

  return train_step


batch_stats = variables['batch_stats']
train_step = make_train_step(model.apply)
for step, batch in enumerate(ds)
  optimizer, batch_stats, loss = train_step(
      optimizer, batch_stats, batch['image'], batch['label'])

备注

  • 注意,optimizer.apply_gradient() 可以接收额外的参数来更新超参数,例如在本例中来自独立函数 get_learning_rate() 的学习率。

更新计划#

  1. 完成关于此 FLIP 的讨论

  2. 为 Optax 添加 等价性测试,确保现有的 flax.optim 优化器与相应的 optax 优化器返回相同的值。

  3. 更新示例以使用 Optax,并验证它们在相同的计算成本下达到了相同的最终性能。

  4. 将 Optax 中缺少的优化器移植(例如 Adafactor),并验证上述要点。

  5. 更新所有文档(包括 README、Flax 指南、HOWTO 等),专门介绍 Optax 优化器。

  6. 创建一个从 flax.optim 迁移到使用 Optax 的过渡指南。该指南还应指向 Optax 的 等价性测试 和更新示例的 pull 请求。

  7. flax.optim 中的优化器标记为已弃用。

注意,所有当前的 Flax 示例都使用 Optax 中已有的优化器

示例

Flax

Optax

评论

imagenet

optim.Momentum

optax.sgd

DynamicScale 可以保持不变。

mnist

optim.Momentum

optax.sgd

nlp_seq

optim.Adam

optax.adamw

pixelcnn

optim.Adam

optax.adam

ppo

optim.Adam

optax.adam

seq2seq

optim.Adam

optax.adam

vae

optim.Adam

optax.adam

wmt

optim.Adam

optax.adamw

(Flax 的 Adam 实现有一个可选参数用于权重衰减,但在 Optax 中,带权重衰减和不带权重衰减的 Adam 是两个不同的别名。)

附录#

设置代码#

以下设置代码可用于运行此 FLIP 中的代码段

import functools
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds


def pp(features):
  return {
      'image': tf.cast(features['image'], tf.float32) / 255 - 0.5,
      'label': features['label'],
  }


class Model(nn.Module):

  @nn.compact
  def __call__(self, inputs):
    x = inputs.reshape([inputs.shape[0], -1])
    x = nn.normalization.BatchNorm(True)(x)
    x = nn.Dense(10)(x)
    x = nn.log_softmax(x)
    return x


def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
  x = (labels[..., None] == jnp.arange(num_classes)[None])
  x = jax.lax.select(
      x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
  return x.astype(jnp.float32)


def xent_loss(logits, labels):
  return -jnp.sum(
      onehot(labels, num_classes=10) * logits) / labels.size


def get_learning_rate(step):
  return 0.1


model = Model()
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
jax.tree_util.tree_map(jnp.shape, variables)