处理整个数据集#

出于效率原因,我们形成了包含多个示例的批次,并并行处理它们。尤其是在评估模型时,重要的是要处理所有示例,并 **避免丢失** 末尾无法形成完整批次的剩余示例。

问题#

在单个设备上评估时,可以丢弃最后一个不完整的批次,也可以形成一个形状与前面批次不同的最后一个批次。后者的缺点是会触发 eval_step() 的 **重新编译**,因为 XLA 不是形状多态的。

collections.Counter(
    tuple(batch['image'].shape)
    for batch in tfds.load('mnist', split='test').batch(per_device_batch_size)
)
# output:
# Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19})

使用多个设备进行数据并行时,问题会更加突出。如果批次大小 **不能被设备数量整除**,那么最后一步必须在单个设备(或设备子集)上执行。通常会丢弃最后一个批次,但这会导致结果不正确。

sum(
    np.prod(batch['label'].shape)
    for batch in tfds.load('mnist', split='test')
        .batch(per_device_batch_size, drop_remainder=True)
        .batch(jax.local_device_count())
)
# output:
# 9728

使用多个主机进一步复杂化了情况,因为 JAX 使用 SPMD 范式,每个主机都必须执行相同的程序。我们通常会使用 tfds.split_for_jax_process() 为不同的主机形成非重叠分割,但这会导致 **不同主机上的数量不同**,从而导致所有示例都将被处理时, JAX 程序不同。

process_count = 6
[
    len(tfds.load(dataset_name, split=tfds.split_for_jax_process(
        'test', process_index=process_index, process_count=process_count)))
    for process_index in range(process_count)
]
# output:
# [1667, 1667, 1667, 1667, 1666, 1666]

解决方案:填充#

尽管可以通过巧妙地调整不同主机上不同设备执行的批次数量来解决此问题,但这种解决方案很快就变得很复杂,并且使主评估循环难以阅读,因为有很多繁琐的逻辑。

解决此问题的更直接的方法是在数据集的末尾使用填充,以确保最后一个批次与前面批次的大小相同。

手动实现#

最后一个批次手动填充,使其包含与前面批次中相同的示例数量。填充示例的预测将从计算中丢弃。

shard = lambda x: einops.rearrange(
    x, '(d b) ... -> d b ...', d=jax.local_device_count())
unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...')

correct = total = 0
for batch in ds.as_numpy_iterator():
  images = batch['image']
  n = len(images)
  padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype)
  padded_images = np.concatenate([images, padding])
  preds = unshard(get_preds(variables, shard(padded_images)))[:n]
  total += n
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

使用 pad_shard_unpad()#

上述模式,即 pad→shard→predict→unshard→unpad 序列,可以提取到一个实用程序包装器 pad_shard_unpad() 中,这极大地简化了上述评估循环。

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      vs, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

eval_step() 中计算指标#

我们通常不希望返回预测并在主评估循环中计算指标,而是希望将指标计算作为评估步骤的一部分,尤其是在使用 clu.metricsclu.metrics 等库时。

在这种情况下,我们希望将指标作为 static_argnums 传递(即不进行分片/填充),并将返回值也视为 static_return(即不进行取消分片或取消填充)。

def eval_step(metrics, variables, batch):
  print('retrigger compilation', {k: v.shape for k, v in batch.items()})
  preds = model.apply(variables, batch['image'])
  correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()
  total = batch['mask'].sum()
  return dict(
      correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),
      total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),
  )

eval_step = jax.pmap(eval_step, axis_name='batch')
eval_step = flax.jax_utils.pad_shard_unpad(
    eval_step, static_argnums=(0, 1), static_return=True)

添加“无限填充”#

上述解决方案在大多数情况下都有效,但它有一些限制

  1. 在罕见的情况下,即使在多个主机上对数据集进行均匀分割也会导致批次数量不同。假设有一个包含 n=4097 个示例的数据集,并在 h=8 上进行评估,每个主机都拥有 d=8 个本地设备,并形成 b=128 的设备级批次大小。如果使用均匀数据集分割,第一个主机将获得 4096/8+1==513 个示例,而所有其他主机将获得 4096/8==512 个示例。如果形成 d*b==512 的每个主机批次,这将导致第一个主机有两个批次,而所有其他主机只有一个批次,从而违反 SPMD 原则,并在最后一个 psum() 指令中挂起多主机设置(该指令将仅由第一个主机执行,而不会由其他主机执行)。

  2. 当通过使用 ds.filter() 动态丢弃示例时。

在这些更复杂的情况下,我们可以独立地向每个主机添加“无限填充”,并继续处理示例,直到 *所有* 主机都用完未填充的示例。

correct = total = 0
for batch in ds.as_numpy_iterator():
  n = count_p(batch['mask'])[0].item()  # adds sync barrier
  if not n: break

  preds = get_preds(vs, batch['image']).argmax(axis=-1)
  total += n
  correct += count_correct_p(batch['label'], preds, batch['mask'])[0]