Ensembling on multiple devices#
We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble is equal to the number of available devices. In short, this change can be described as:

- parallelizing multiple functions with jax.pmap(),
- splitting the random seed to obtain different parameter initializations,
- replicating the inputs and unreplicating the outputs where necessary,
- averaging the probabilities from the different devices to compute the predictions.

In this HOWTO we omit some of the code, such as the imports, the CNN module, and the metric computation, but they can be found in the MNIST example.
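For reference, the omitted pieces look roughly as follows. This is only a sketch based on the definitions used in the MNIST example; treat it as an assumption about the surrounding code rather than an exact copy.

import functools

from absl import logging
from flax import jax_utils
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax

class CNN(nn.Module):
  """A simple CNN model, as in the MNIST example."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x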
Parallelizing functions#
We start by creating a parallelized version of create_train_state(), which retrieves the initial parameters of the model. We do this using jax.pmap(). The effect of "pmapping" a function is that it will compile the function with XLA (similar to jax.jit()), but execute it in parallel across XLA devices (e.g., GPUs/TPUs).
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
Note that for the single-model code above, we use jax.jit() to lazily initialize the model (see the documentation of Module.init for more details). For the ensembling case, jax.pmap() will by default map over the first axis of the provided rng argument, so we should make sure that we provide a different value for each device when we call this function later on.

Also note how we specify that learning_rate and momentum are static arguments, which means that the concrete values of these arguments are used rather than abstract shapes. This is necessary because the provided arguments will be scalar values. For more details, see JIT mechanics: tracing and static variables.
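As a small usage sketch of the pmapped version (the hyperparameter values below are placeholders, not part of the original example): we pass one RNG key per device, while the scalar hyperparameters are broadcast as static values.

# One key per device, so every ensemble member gets a different initialization.
rngs = jax.random.split(jax.random.key(0), jax.device_count())
state = create_train_state(rngs, 0.1, 0.9)  # learning_rate and momentum are static
# Every leaf of state.params now has a leading axis of size jax.device_count().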
Next, we simply do the same for the apply_model() and update_model() functions. To compute the predictions of the ensemble, we average the individual probabilities. We use jax.lax.pmean() to compute that average across devices. Note that this also requires us to specify the axis_name for both jax.pmap() and jax.lax.pmean().
@jax.jit
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
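To make the role of axis_name concrete, here is a tiny standalone sketch (not part of the example itself): every device holds one value, and jax.lax.pmean() returns the mean over the 'ensemble' axis on every device.

x = jnp.arange(float(jax.device_count()))  # one value per device
mean_fn = jax.pmap(lambda v: jax.lax.pmean(v, axis_name='ensemble'),
                   axis_name='ensemble')
print(mean_fn(x))  # every entry equals the mean of x across devices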
Training the ensemble#
Next, we transform the train_epoch() function. When calling the pmapped functions above, we mainly need to take care of replicating values across all devices where necessary, and unreplicating the return values.
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  epoch_loss = []
  epoch_accuracy = []
  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  epoch_loss = []
  epoch_accuracy = []
  for perm in perms:
    batch_images = jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(jax_utils.unreplicate(loss))
    epoch_accuracy.append(jax_utils.unreplicate(accuracy))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
As you can see, we don't have to make any changes to the logic around state. This is because, as we will see in the train code below, the train state is already replicated, so things just work when we pass it to apply_model() and update_model(), since those functions are pmapped. However, the train dataset is not yet replicated, so we do that here. Since replicating the entire train dataset would consume a lot of memory, we replicate it at the batch level.
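For clarity, flax.jax_utils.replicate() stacks a copy of its argument for every local device, and jax_utils.unreplicate() pulls a single copy back out. A rough sketch of the shapes involved (the batch size of 32 is just an illustrative value):

batch = jnp.ones([32, 28, 28, 1])
replicated = jax_utils.replicate(batch)       # shape (jax.local_device_count(), 32, 28, 28, 1)
restored = jax_utils.unreplicate(replicated)  # shape (32, 28, 28, 1) again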
We can now rewrite the actual training logic. This consists of two simple changes: making sure the RNG is split across devices when we pass it to create_train_state(), and replicating the test dataset, which is much smaller than the train dataset, so we can simply replicate it in its entirety.
train_ds, test_ds = get_datasets()
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum)
for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)
  _, test_loss, test_accuracy = apply_model(
      state, test_ds['image'], test_ds['label'])
  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))
train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds)
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(jax.random.split(init_rng, jax.device_count()),
                           learning_rate, momentum)
for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)
  _, test_loss, test_accuracy = jax_utils.unreplicate(
      apply_model(state, test_ds['image'], test_ds['label']))
  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))
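After training, state remains replicated: every leaf carries a leading axis with one entry per ensemble member. If you need the trained parameters back on the host, e.g. for inspection, a minimal sketch (not part of the original example) is:

host_state = jax.device_get(state)  # plain NumPy arrays, keeping the leading axis of size jax.device_count()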