迁移学习#
本指南演示了使用 Flax 进行迁移学习工作流的各个部分。根据任务的不同,预训练模型可以仅用作特征提取器,也可以作为更大模型的一部分进行微调。
本指南演示了如何
从 HuggingFace Transformers 加载预训练模型并从该预训练模型中提取特定的子模块。
创建分类器模型。
将预训练参数迁移到新的模型结构。
使用 Optax 为模型的不同部分创建单独训练的优化器。
设置模型进行训练。
性能说明
根据您的任务,本指南中的某些内容可能不是最佳的。例如,如果您只打算在预训练模型之上训练线性分类器,那么最好只提取一次特征嵌入,这可以使训练速度快得多,并且您可以使用针对线性回归或逻辑分类的专用算法。本指南展示了如何使用所有模型参数进行迁移学习。
设置#
# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/google/jax#installation.
! pip install -U -q flax jax jaxlib
创建模型加载函数#
要加载预训练的分类器,为了方便起见,首先创建一个返回 Flax Module
及其预训练变量的函数。
在下面的代码中,load_model
函数使用 HuggingFace 的 FlaxCLIPVisionModel
模型(来自 Transformers 库)并提取 FlaxCLIPModule
模块。
%%capture
from IPython.display import clear_output
from transformers import FlaxCLIPModel
# Note: FlaxCLIPModel is not a Flax Module
def load_model():
clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')
clear_output(wait=False) # Clear the loading messages
module = clip.module # Extract the Flax Module
variables = {'params': clip.params} # Extract the parameters
return module, variables
请注意,FlaxCLIPVisionModel
本身不是 Flax Module
,这就是我们需要执行此额外步骤的原因。
提取子模块#
从上面的代码片段中调用 load_model
会返回 FlaxCLIPModule
,它由 text_model
和 vision_model
子模块组成。
提取在 .setup()
中定义的 vision_model
子模块及其变量的一种简单方法是使用 flax.linen.Module.bind
在 clip
模块上,紧随其后的是在 vision_model
子模块上使用 flax.linen.Module.unbind
。
import flax.linen as nn
clip, clip_variables = load_model()
vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()
创建分类器#
要创建分类器,请定义一个新的 Flax Module
,它包含 backbone
(预训练的视觉模型)和 head
(分类器)子模块。
from typing import Callable
import jax.numpy as jnp
import jax
class Classifier(nn.Module):
num_classes: int
backbone: nn.Module
@nn.compact
def __call__(self, x):
x = self.backbone(x).pooler_output
x = nn.Dense(
self.num_classes, name='head', kernel_init=nn.zeros)(x)
return x
要构建分类器 model
,将 vision_model
模块作为 backbone
传递给 Classifier
。然后,可以通过传递用于推断参数形状的假数据来随机初始化模型的 params
。
num_classes = 3
model = Classifier(num_classes=num_classes, backbone=vision_model)
x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.key(1), x)
params = variables['params']
迁移参数#
由于 params
目前是随机的,因此必须将来自 vision_model_vars
的预训练参数迁移到 params
结构的适当位置(即 backbone
)。
params['backbone'] = vision_model_vars['params']
注意:如果模型包含其他变量集合,例如 batch_stats
,则也必须迁移这些集合。
优化#
如果您需要分别训练模型的不同部分,您有三个选择
使用
stop_gradient
。为
jax.grad
过滤参数。为不同的参数使用多个优化器。
对于大多数情况,我们建议通过 Optax 的 multi_transform
使用多个优化器,因为它既高效又可以轻松扩展以实现许多微调策略。
optax.multi_transform#
要使用 optax.multi_transform
,必须定义以下内容
参数分区。
分区与其优化器之间的映射。
具有与参数相同形状的 pytree,但其叶子包含相应的分区标签。
要使用 optax.multi_transform
冻结上述模型中的层,可以使用以下设置
定义
trainable
和frozen
参数分区。对于
trainable
参数,选择 Adam (optax.adam
) 优化器。
对于
frozen
参数,选择optax.set_to_zero
优化器。此虚拟优化器会将梯度清零,因此不会进行训练。使用
flax.traverse_util.path_aware_map
将参数映射到分区,将来自backbone
的叶子标记为frozen
,并将其余部分标记为trainable
。
from flax import traverse_util
import optax
partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = traverse_util.path_aware_map(
lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)
# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
FrozenDict({
backbone: {
embeddings: {
class_embedding: 'frozen',
patch_embedding: {
kernel: 'frozen',
},
},
},
head: {
bias: 'trainable',
kernel: 'trainable',
},
})
要实现 差异学习率,可以使用任何其他优化器替换 optax.set_to_zero
,可以根据任务选择不同的优化器和分区方案。有关高级优化器的更多信息,请参阅 Optax 的 组合优化器 文档。
创建 TrainState
#
定义完模块、参数和优化器后,就可以像往常一样构建 TrainState
。
from flax.training.train_state import TrainState
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=tx)
由于优化器负责冻结或微调策略,因此 train_step
不需要任何额外更改,可以正常进行训练。