迁移学习#

本指南演示了使用 Flax 进行迁移学习工作流的各个部分。根据任务的不同,预训练模型可以仅用作特征提取器,也可以作为更大模型的一部分进行微调。

本指南演示了如何

  • 从 HuggingFace Transformers 加载预训练模型并从该预训练模型中提取特定的子模块。

  • 创建分类器模型。

  • 将预训练参数迁移到新的模型结构。

  • 使用 Optax 为模型的不同部分创建单独训练的优化器。

  • 设置模型进行训练。

性能说明

根据您的任务,本指南中的某些内容可能不是最佳的。例如,如果您只打算在预训练模型之上训练线性分类器,那么最好只提取一次特征嵌入,这可以使训练速度快得多,并且您可以使用针对线性回归或逻辑分类的专用算法。本指南展示了如何使用所有模型参数进行迁移学习。


设置#

# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/google/jax#installation.
! pip install -U -q flax jax jaxlib

创建模型加载函数#

要加载预训练的分类器,为了方便起见,首先创建一个返回 Flax Module 及其预训练变量的函数。

在下面的代码中,load_model 函数使用 HuggingFace 的 FlaxCLIPVisionModel 模型(来自 Transformers 库)并提取 FlaxCLIPModule 模块。

%%capture
from IPython.display import clear_output
from transformers import FlaxCLIPModel

# Note: FlaxCLIPModel is not a Flax Module
def load_model():
  clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')
  clear_output(wait=False) # Clear the loading messages
  module = clip.module # Extract the Flax Module
  variables = {'params': clip.params} # Extract the parameters
  return module, variables

请注意,FlaxCLIPVisionModel 本身不是 Flax Module,这就是我们需要执行此额外步骤的原因。

提取子模块#

从上面的代码片段中调用 load_model 会返回 FlaxCLIPModule,它由 text_modelvision_model 子模块组成。

提取在 .setup() 中定义的 vision_model 子模块及其变量的一种简单方法是使用 flax.linen.Module.bindclip 模块上,紧随其后的是在 vision_model 子模块上使用 flax.linen.Module.unbind

import flax.linen as nn

clip, clip_variables = load_model()
vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()

创建分类器#

要创建分类器,请定义一个新的 Flax Module,它包含 backbone(预训练的视觉模型)和 head(分类器)子模块。

from typing import Callable
import jax.numpy as jnp
import jax

class Classifier(nn.Module):
  num_classes: int
  backbone: nn.Module
  

  @nn.compact
  def __call__(self, x):
    x = self.backbone(x).pooler_output
    x = nn.Dense(
      self.num_classes, name='head', kernel_init=nn.zeros)(x)
    return x

要构建分类器 model,将 vision_model 模块作为 backbone 传递给 Classifier。然后,可以通过传递用于推断参数形状的假数据来随机初始化模型的 params

num_classes = 3
model = Classifier(num_classes=num_classes, backbone=vision_model)

x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.key(1), x)
params = variables['params']

迁移参数#

由于 params 目前是随机的,因此必须将来自 vision_model_vars 的预训练参数迁移到 params 结构的适当位置(即 backbone)。

params['backbone'] = vision_model_vars['params']

注意:如果模型包含其他变量集合,例如 batch_stats,则也必须迁移这些集合。

优化#

如果您需要分别训练模型的不同部分,您有三个选择

  1. 使用 stop_gradient

  2. jax.grad 过滤参数。

  3. 为不同的参数使用多个优化器。

对于大多数情况,我们建议通过 Optaxmulti_transform 使用多个优化器,因为它既高效又可以轻松扩展以实现许多微调策略。

optax.multi_transform#

要使用 optax.multi_transform,必须定义以下内容

  1. 参数分区。

  2. 分区与其优化器之间的映射。

  3. 具有与参数相同形状的 pytree,但其叶子包含相应的分区标签。

要使用 optax.multi_transform 冻结上述模型中的层,可以使用以下设置

  • 定义 trainablefrozen 参数分区。

  • 对于 trainable 参数,选择 Adam (optax.adam) 优化器。

  • 对于 frozen 参数,选择 optax.set_to_zero 优化器。此虚拟优化器会将梯度清零,因此不会进行训练。

  • 使用 flax.traverse_util.path_aware_map 将参数映射到分区,将来自 backbone 的叶子标记为 frozen,并将其余部分标记为 trainable

from flax import traverse_util
import optax

partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = traverse_util.path_aware_map(
  lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)

# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
FrozenDict({
    backbone: {
        embeddings: {
            class_embedding: 'frozen',
            patch_embedding: {
                kernel: 'frozen',
            },
        },
    },
    head: {
        bias: 'trainable',
        kernel: 'trainable',
    },
})

要实现 差异学习率,可以使用任何其他优化器替换 optax.set_to_zero,可以根据任务选择不同的优化器和分区方案。有关高级优化器的更多信息,请参阅 Optax 的 组合优化器 文档。

创建 TrainState#

定义完模块、参数和优化器后,就可以像往常一样构建 TrainState

from flax.training.train_state import TrainState

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=tx)

由于优化器负责冻结或微调策略,因此 train_step 不需要任何额外更改,可以正常进行训练。