使用 FP8 的用户指南#

JAX 支持多种 FP8 格式,包括 E4M3 (jnp.float8_e4m3fn) 和 E5M2 (jnp.float8_e5m2)。由于 FP8 数据类型的范围有限,因此必须对更高精度的数据进行缩放以使其适合 FP8 可表示的范围内,这一过程称为量化 (Q)。反之,反量化 (DQ) 将 FP8 数据重新缩放到其原始类型。

虽然 jnp.dot 支持 FP8 输入,但某些限制使其不适用于实际应用。或者,我们的编译器 XLA 可以识别以下模式->DQ->Dot,然后调用 FP8 后端(例如,GPU 的 cublasLt)。FLAX 将此类模式封装到 nn.fp8_ops.Fp8DotGeneralOp 模块中,允许用户轻松地为现有层(例如,nn.Dense)配置它。

本教程将引导您了解如何使用它的基本知识。

设置我们的环境#

在这里,我们提供设置笔记本环境所需的代码。此外,我们定义了一个函数来检查 XLA 优化的 HLO 是否确实在幕后调用了 FP8 点操作。

注意:本教程依赖于 XLA-FP8 功能,该功能仅在 NVIDIA Hopper GPU 或更高版本上受支持。

import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops

e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)

assert jtu.is_cuda_compute_capability_at_least("9.0")

def check_fp8_call(lowered):
  hlo = lowered.compile()
  if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
    print("Fp8 call detected!")
  else:
    print("No Fp8 call!")

FLAX 低级 API#

JAX 点操作(例如 jnp.dot)支持 FP8 dtype 输入。因此,执行以下调用是合法的

key = random.key(0)
A = random.uniform(key, (16, 32))
B = random.uniform(key, (32, 64))
@jax.jit
def dot_fp8(A, B):
  return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(A, B))

但是,这种方法有两个主要问题。首先,jnp.dot 不接受操作数的缩放因子,默认情况下缩放因子为 1.0。其次,它不支持混合 FP8 数据类型的操作数。例如,当操作数为 E5M2 和 E4M3 时,点积使用提升的 FP16 数据类型执行。

在实际场景中,必须指定缩放因子,无论是推断的校准还是训练期间用户定义的算法。此外,通常的做法是使用 E5M2 表示梯度,使用 E4M3 表示激活和内核。这些限制使得这种方法在实际应用中不太实用。

为了解决这些限制并创建更通用的 FP8 点积,我们建议利用 XLA-FP8。让我们从简单的缩放策略开始。

当前缩放#

缩放因子通常定义为 scale = amax(x) / MAX,其中 amax 是查找张量绝对最大值的运算,而 MAX 是目标 dtype 可表示范围内的最大值。这种缩放方法允许我们直接从点积的当前操作数张量中推导出缩放因子。

@jax.jit
def dot_fp8(A, B):
  A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX
  B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX
  A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)
  B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)

  C = jnp.dot(A, B)
  return C

C = dot_fp8(A, B)
check_fp8_call(dot_fp8.lower(A, B))

如代码所示,我们对点积的操作数执行假量化 (fp8_ops.quantize_dequantize)。虽然 jnp.dot 仍然处理更高精度的输入,但 XLA 会检测到此模式并将点操作重写为 FP8 点调用(例如,GPU 的 cublasLt 调用)。这种方法有效地模拟了第一个示例,但提供了更大的灵活性。我们可以控制输入 dtype(这里都设置为 E4M3,但我们可以使用混合的 E4M3 和 E5M2)并定义缩放因子,XLA 可以检测到并在点后端中使用。

当前缩放方法的一个主要问题是计算 A_scaleB_scale 所带来的开销,这需要额外加载操作数张量。为了克服这个问题,我们建议使用延迟缩放。

延迟缩放#

在延迟缩放中,我们使用与 amax 历史相关的缩放因子。缩放因子仍然是标量,但 amax 历史是一个列表,存储最近步骤的 amax 值(例如,1024 步)。这两个张量都是从之前的步骤计算出来的,并在模型参数中维护。

延迟缩放的假量化由 fp8_ops.in_qdq(针对激活和权重)和 fp8_ops.out_qdq(针对梯度)提供。

a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))

@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
            g_scale, g_amax_hist):
  a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
  b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
  
  c = jnp.dot(a, b)
  c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
  return c

C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
            g_scale, g_amax_hist)
check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
                             g_scale, g_amax_hist))

在此示例中,我们首先准备三对缩放因子和 amax 历史记录,并将它们视为从先前步骤计算的结果。然后,我们将 fp8_ops.in_qdq 应用于 jnp.dot 的输入操作数,然后将 fp8_ops.out_qdq 应用于 jnp.dot 的输出。请注意,fp8_ops.out_qdq 将通过 custom_vjp 函数对输出的梯度应用假量化。新的缩放因子和 amax 历史记录将通过它们的梯度返回,这将在下一节中介绍。

FLAX 高级 API#

使用 FLAX 库,将 FP8 操作合并到现有的 FLAX 层中是一个无缝的过程。用户无需操纵量化的低级 API。相反,他们可以使用简单的“代码注入”方法将提供的自定义 FP8 点 (fp8_ops.Fp8DotGeneralOp) 集成到 FLAX 层中。此自定义操作封装了所有与 FP8 相关的任务,包括量化-反量化操作的放置、更新缩放因子的算法以及选择 FP8 dtype 组合进行正向和反向传播。

考虑以下示例

model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)

@jax.jit
def train_step(var, a): 
  c = model.apply(var, a)
  return jnp.sum(c)

check_fp8_call(train_step.lower(params, A))

在此示例中,我们只需设置 dot_general_cls=fp8_ops.Fp8DotGeneralOp 即可使 Dense 层利用 FP8 点操作。模型的使用方式与之前几乎相同。主要区别是添加了一类新参数:缩放因子和 amax 历史记录的集合。在下一节中,我们将探讨如何更新这些参数。

操纵 FP8 参数#

让我们首先检查 params 的数据结构。在下面的代码中,我们删除了参数值,然后显示 PyTree 结构。

params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
    params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)

输出如下

{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',
                                                    'input_scale': '*',
                                                    'kernel_amax_history': '*',
                                                    'kernel_scale': '*',
                                                    'output_grad_amax_history': '*',
                                                    'output_grad_scale': '*'}},
 'params': {'bias': '*', 'kernel': '*'}}

除了预期的 params 之外,还有一个名为 _overwrite_with_gradient 的附加类别。此类别包括三对 amax_historyscale,分别用于激活、内核和点梯度。

更新 FP8 参数的梯度#

现在,我们执行一步训练以获得梯度,并了解如何使用它们来更新参数。

step_fn = jax.jit(jax.grad(train_step, (0, 1)))

grads = step_fn(params, A)

params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')

for key, value in params.items():
  if key.startswith('params'):
    params[key] = value + 0.01 * grads[key]
  if key.startswith('_overwrite_with_gradient'):
    params[key] = grads[key]

params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)

上面的代码演示了如何更新 params_overwrite_with_gradient。对于 params,我们使用公式 new_param = old_param + 0.01 * grads,其中 0.01 是学习率(或者用户可以使用 optax 中的任何优化器)。对于 _overwrite_with_gradient,我们只需使用梯度覆盖旧值。

请注意,flax.training.train_state.TrainState 方便地支持 _overwrite_with_gradient 类别,因此如果用户没有使用自定义 TrainState,则无需修改他们的脚本。

累积 FP8 参数的梯度#

当在分支方式中使用相同参数时,自动微分机制将添加来自这些分支的梯度。这在流水线并行等场景中很常见,其中每个微批次共享相同的参数集。但是,对于 _overwrite_with_gradient 参数,这种通过加法进行累加是没有意义的。相反,我们更喜欢通过取最大值进行自定义累加。

为了解决这个问题,我们引入了一个自定义数据类型 fp8_ops.fp32_max_grad。以下是基本用法示例

fmax32 = fp8_ops.fp32_max_grad

def reuse_fp8_param(x, y, scale, amax_history):
  scale = scale.astype(fmax32)
  amax_history = amax_history.astype(fmax32)

  x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
  y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
  return x + y

reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)

_, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_ah, new_sf)

在此示例中,我们首先将 scaleamax_history 转换为 fp8_ops.fp32_max_grad,然后使用相同的 scaleamax_historyfp8_ops.in_qdq 进行两次调用。在自动微分期间,将来自每个分支的梯度取为最大值,从而得到正确的结果

1.0 [3. 0. 0. ... 0. 0. 0.]

如果我们不进行类型转换,我们会得到以下结果,这意味着两个分支的梯度被加在一起

2.0 [5. 0. 0. ... 0. 0. 0.]

如果用户选择使用高级 API,则此转换已包含在内。