Ensembling on multiple devices

We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble equals the number of available devices. In short, the changes can be summarized as follows:

  • parallelize several functions using jax.pmap(),

  • split the random seed to obtain different parameter initializations,

  • replicate the inputs and unreplicate the outputs where necessary,

  • average the probabilities from the different devices to compute the predictions.

In this HOWTO we omit some of the code, such as the imports, the CNN module, and the metric computation, but it can all be found in the MNIST example.
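
For reference, below is a minimal sketch of the omitted imports and of a CNN module along the lines of the MNIST example; the exact module definition there may differ, so treat this only as a stand-in.

import functools
import logging

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import jax_utils
from flax import linen as nn
from flax.training import train_state

# get_datasets(), which loads MNIST, is also defined in the MNIST example.

class CNN(nn.Module):
  """A simple CNN, roughly following the MNIST example."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # Flatten.
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x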

Parallel functions

We start by creating a parallelized version of create_train_state(), which retrieves the initial parameters of the models. We do this using jax.pmap(). The effect of "pmapping" a function is that it will compile the function with XLA (similar to jax.jit()), but execute it in parallel on XLA devices (e.g., GPUs/TPUs).

Single model:

@jax.jit
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

Ensemble:

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

Note that for the single-model code above we use jax.jit() to lazily initialize the model (see the documentation of Module.init for more details). For the ensembling case, jax.pmap() will by default map over the first axis of the provided argument rng, so we should make sure that we provide a different value for each device when we call this function later on.
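
As a quick illustration (a usage sketch, not part of the original example, with placeholder values for learning_rate and momentum), the pmapped create_train_state() expects one RNG key per device, and every leaf of the returned train state then carries a leading device axis:

# Hypothetical usage: one RNG key per device.
rngs = jax.random.split(jax.random.key(0), jax.device_count())
state = create_train_state(rngs, 0.1, 0.9)

# Every parameter now has a leading axis of size jax.device_count().
print(jax.tree_util.tree_map(jnp.shape, state.params))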

Also note how we specify that learning_rate and momentum are static arguments, which means the concrete values of these arguments will be used rather than abstract shapes. This is necessary because the provided arguments will be scalar values. For more details, see JIT mechanics: tracing and static variables.
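
To see why this matters, here is a toy sketch (not from the original example): without static_broadcasted_argnums, jax.pmap() would try to map the scalar argument over a leading device axis it does not have; marking it static instead broadcasts the same Python value to every device.

# Toy sketch: 'factor' is a static argument broadcast to all devices.
@functools.partial(jax.pmap, static_broadcasted_argnums=(1,))
def scale(x, factor):
  return x * factor

xs = jnp.ones((jax.device_count(), 3))
print(scale(xs, 2.0))  # Each device multiplies its slice by 2.0.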

Next, we simply do the same thing for the apply_model() and update_model() functions. To compute the predictions of the ensemble we average the individual probabilities. We use jax.lax.pmean() to compute the mean across devices, which also requires us to specify the axis_name both in jax.pmap() and in jax.lax.pmean().

Single model:

@jax.jit
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)

  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

Ensemble:

@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  # Average the per-device probabilities to form the ensemble prediction.
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
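
The following toy sketch (not part of the original example) shows what jax.lax.pmean() does under a matching axis_name: every device ends up with the mean of the per-device values, which is exactly how the per-device probabilities are averaged above.

# Toy sketch: average one value per device across all devices.
device_mean = jax.pmap(
    lambda x: jax.lax.pmean(x, axis_name='ensemble'), axis_name='ensemble')

values = jnp.arange(jax.device_count(), dtype=jnp.float32)
print(device_mean(values))  # Every entry equals the mean of 'values'.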

Training the ensemble

Next, we transform the train_epoch() function. When calling the pmapped functions above, the main thing we need to take care of is replicating the arguments across all devices where necessary, and unreplicating the return values.

Single model:

def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy

Ensemble:

def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    # Replicate the batch so every device trains its own ensemble member on it.
    batch_images = jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(jax_utils.unreplicate(loss))
    epoch_accuracy.append(jax_utils.unreplicate(accuracy))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy

As you can see, we do not have to make any changes to the logic around state. This is because, as we will see in the training code below, the train state is already replicated, so passing it to the pmapped apply_model() and update_model() just works. The training dataset, however, is not yet replicated, so we do that here. Since replicating the entire training dataset would consume a lot of memory, we replicate it at the batch level.
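
As a small sketch of what replication looks like (not part of the original example), flax.jax_utils.replicate() adds a leading device axis to every leaf of a pytree, and unreplicate() strips it again by taking the entry from the first device:

# Sketch: replicate a batch so each device receives an identical copy.
batch = jnp.ones((32, 28, 28, 1))
replicated = jax_utils.replicate(batch)
print(replicated.shape)  # (jax.local_device_count(), 32, 28, 28, 1)

print(jax_utils.unreplicate(replicated).shape)  # (32, 28, 28, 1)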

We can now rewrite the actual training logic. This consists of two simple changes: making sure the RNG is split into one key per device when it is passed to create_train_state(), and replicating the test dataset. The test dataset is much smaller than the training dataset, so we can simply replicate it in its entirety.

Single model:

train_ds, test_ds = get_datasets()

rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum)


for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)

  _, test_loss, test_accuracy = apply_model(
      state, test_ds['image'], test_ds['label'])

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))

Ensemble:

train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds)
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
# Split the RNG so that every device gets a different parameter initialization.
state = create_train_state(jax.random.split(init_rng, jax.device_count()),
                           learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)

  _, test_loss, test_accuracy = jax_utils.unreplicate(
      apply_model(state, test_ds['image'], test_ds['label']))

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))
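
Finally, if you do not have multiple accelerators available, one way to try this HOWTO out (not part of the original example) is to let XLA emulate several host devices before JAX is imported, so that jax.device_count() reports more than one CPU device:

# Sketch: emulate 8 CPU devices; must be set before importing jax.
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
print(jax.device_count())  # 8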