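Model surgery#

Model surgery is the act of modifying the building blocks and parameters of an existing neural network, for example by replacing layers, manipulating parameters or state, or even "monkey patching". In this guide, you will learn how to perform model surgery in Flax NNX through several real-world scenarios:

  • Pythonic nnx.Module manipulation: Manipulating the sub-Modules of a given model in a Pythonic way.

  • Manipulation of an abstract model or state: A key trick for manipulating flax.nnx.Module objects and their state without allocating memory.

  • Checkpoint surgery, from a raw state to a model: How to manipulate parameter state when it is incompatible with the existing model code.

  • Partial initialization: How to initialize only part of a model from scratch, using either a naive method or a memory-efficient method.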

from typing import *
from pprint import pprint
import functools

import jax
from jax import lax, numpy as jnp, tree_util as jtu

from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax

key = jax.random.key(0)

class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

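Pythonic nnx.Module manipulation#

Model surgery is easier to perform when:

  1. You already have a fully fleshed-out model loaded with the correct parameters; and

  2. You do not intend to change the model definition code.

You can perform a variety of Pythonic operations on its sub-Modules, such as sub-Module swapping, Module sharing, variable sharing, and monkey-patching: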

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))

# Sub-`Module` swapping.
original1, original2 = model.linear1, model.linear2
model.linear1, model.linear2 = model.linear2, model.linear1
np.testing.assert_allclose(model(x), original1(original2(x)))

# `Module` sharing (tying all weights together).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear2 = model.linear1
assert not hasattr(nnx.state(model), 'linear2')
np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))

# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel  # the bias parameter is kept separate
assert hasattr(nnx.state(model), 'linear2')
assert hasattr(nnx.state(model)['linear2'], 'bias')
assert not hasattr(nnx.state(model)['linear2'], 'kernel')

# Monkey-patching.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
def awesome_layer(x): return x
model.linear2 = awesome_layer
np.testing.assert_allclose(model(x), model.linear1(x))

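Creating an abstract model or state without memory allocation#

For more complex model surgery, a key technique is to create and manipulate an abstract model or state without allocating any real parameter data. This makes trial iterations faster and removes any concern about memory constraints.

To create an abstract model:

  • Create a function that returns a valid Flax NNX model; and

  • Run nnx.eval_shape (not jax.eval_shape) on it.

You can then use nnx.split as usual to get its abstract state. Note that every field that would be a jax.Array in a real model is now an abstract jax.ShapeDtypeStruct, which carries only shape, dtype, and sharding information.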

abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
gdef, abs_state = nnx.split(abs_model)
pprint(abs_state)
State({
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  }
})

When you fill the value attribute of every nnx.VariableState pytree leaf with a real jax.Array, the abstract model becomes equivalent to a real model.

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
abs_state['linear1']['kernel'].value = model.linear1.kernel
abs_state['linear1']['bias'].value = model.linear1.bias
abs_state['linear2']['kernel'].value = model.linear2.kernel
abs_state['linear2']['bias'].value = model.linear2.bias
nnx.update(abs_model, abs_state)
np.testing.assert_allclose(abs_model(x), model(x))  # They are equivalent now!

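Checkpoint surgery#

With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make it fit your given model code, and then call nnx.update to merge them.

This is helpful when you are changing the model code substantially (for example, when migrating from Flax Linen to Flax NNX) and the old weights are no longer naturally compatible.

Here is a simple example: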

# Save a version of model into a checkpoint
checkpointer = orbax.PyTreeCheckpointer()
old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
checkpointer.save('/tmp/nnx-surgery-state', nnx.state(old_model), force=True)

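In this new model, the sub-Modules are renamed from linear(1|2) to layer(1|2). Because the pytree structure has changed, the old checkpoint cannot be loaded directly with the new model state structure: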

class ModifiedTwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(dim, dim, rngs=rngs)  # no longer linear1!
    self.layer2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.layer1(x)
    return self.layer2(x)

abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
try:
  with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model))
  print(with_item)
except Exception as e:
  print(f'This will throw error: {type(e)}: {e}')
This will throw error: <class 'ValueError'>: Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.

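However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition.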

def process_raw_dict(raw_state_dict):
  flattened = nnx.traversals.flatten_mapping(raw_state_dict)
  # Cut the '.value' postfix on every leaf path.
  flattened = {(path[:-1] if path[-1] == 'value' else path): value
               for path, value in flattened.items()}
  return nnx.traversals.unflatten_mapping(flattened)

# Make your local change on the checkpoint dictionary.
raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw_dict)
raw_dict['layer1'] = raw_dict.pop('linear1')
raw_dict['layer2'] = raw_dict.pop('linear2')

# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)

np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 ,  0.01005968],
       [ 0.26146442,  1.1247735 ,  0.54563737, -0.374164  ],
       [ 1.0281805 , -0.6798804 , -0.1488401 ,  0.05694951],
       [-0.44308168, -0.60587114,  0.434087  , -0.40541083]],      dtype=float32)}},
 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'value': Array([[ 0.21010089,  0.8289361 ,  0.04589564,  0.5422644 ],
       [ 0.41914317,  0.84359694, -0.47937787, -0.49135214],
       [-0.46072108,  0.4630125 ,  0.39276958, -0.9441406 ],
       [-0.6690758 , -0.18474789, -0.57622856,  0.4821079 ]],      dtype=float32)}}}
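
When many keys need renaming, the two steps above (stripping the trailing 'value' key and renaming top-level entries) can be combined into one pass over flattened paths. The sketch below uses plain Python dictionaries only, so it does not depend on any Flax or Orbax API. The helpers flatten, unflatten, and rename_checkpoint are hypothetical names introduced for illustration.

```python
def flatten(tree, prefix=()):
  """Flattens a nested dict into {path_tuple: leaf}."""
  flat = {}
  for key, value in tree.items():
    path = prefix + (key,)
    if isinstance(value, dict):
      flat.update(flatten(value, path))
    else:
      flat[path] = value
  return flat


def unflatten(flat):
  """Rebuilds a nested dict from {path_tuple: leaf}."""
  tree = {}
  for path, value in flat.items():
    node = tree
    for key in path[:-1]:
      node = node.setdefault(key, {})
    node[path[-1]] = value
  return tree


def rename_checkpoint(raw_state_dict, renames):
  """Drops trailing 'value' keys and renames top-level entries."""
  result = {}
  for path, value in flatten(raw_state_dict).items():
    if len(path) > 1 and path[-1] == 'value':
      path = path[:-1]
    path = (renames.get(path[0], path[0]),) + path[1:]
    result[path] = value
  return unflatten(result)


# Toy stand-in for a restored checkpoint; real leaves would be arrays.
raw = {
    'linear1': {'kernel': {'value': [[1.0]]}, 'bias': {'value': [0.0]}},
    'linear2': {'kernel': {'value': [[2.0]]}, 'bias': {'value': [0.5]}},
}
renamed = rename_checkpoint(raw, {'linear1': 'layer1', 'linear2': 'layer2'})
```

Keeping the renames in a single mapping makes the surgery easy to review and to extend when more sub-Modules change names.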

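Partial initialization#

In some cases (for example, with LoRA (Low-Rank Adaptation)), you may want to randomly initialize only part of your model's parameters. This can be done with:

  • Naive partial initialization; or

  • Memory-efficient partial initialization.

Naive partial initialization#

To do naive partial initialization, you can initialize the whole model and then swap in the pretrained parameters. However, this approach may allocate extra memory midway if your modification requires re-creating module parameters that you will later discard. An example follows.

Note: You can use jax.live_arrays() to check all the arrays live in memory at any given time. This call can be "messed up" when you run a single Jupyter notebook cell multiple times (because of garbage collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output.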

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# This line creates an extra kernel and bias inside the new LoRALinear!
# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
      ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 discarded - only lora_a & lora_b are used in model)')
Number of jax arrays in memory at start: 38
Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)
Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)
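
Counting live arrays by hand before and after each step is easy to get wrong. One option is to wrap the bookkeeping in a small context manager. This is a minimal sketch; report_live_arrays is a hypothetical helper, not part of JAX, and the absolute counts it prints depend on what else is alive in the process.

```python
import contextlib

import jax
import jax.numpy as jnp


@contextlib.contextmanager
def report_live_arrays(label):
  """Prints how the number of live JAX arrays changes across the block."""
  before = len(jax.live_arrays())
  yield
  after = len(jax.live_arrays())
  print(f'{label}: {before} -> {after} live arrays ({after - before:+d})')


with report_live_arrays('allocate one array'):
  kept = jnp.ones((4, 4))  # stays alive because `kept` still references it
```

As the note above says, the numbers are only reproducible in a fresh Python process.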

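Memory-efficient partial initialization#

To do memory-efficient partial initialization, use nnx.jit's efficiently compiled code to make sure only the state parameters you need are initialized: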

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # Create a new state.
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  # Add the existing state.
  nnx.update(model, old_state)
  return model

print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')
# Note that `old_state` will be deleted after this `partial_init` call.
good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 new created - lora_a and lora_b)')
Number of JAX Arrays in memory at start: 44
Number of JAX Arrays in memory at end: 46 (2 new created - lora_a and lora_b)
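
The saving comes from how jitted functions produce results: only the values a function returns come back as arrays, and intermediate values that do not contribute to the outputs can be dropped by the compiler. The sketch below illustrates this with plain jax.jit, assuming (as the comment in the example above states) that nnx.jit wraps jax.jit. The function name init_used_only is illustrative.

```python
import jax


@jax.jit
def init_used_only(key):
  key_unused, key_used = jax.random.split(key)
  # Traced, but never returned: no output array is produced for it.
  unused = jax.random.normal(key_unused, (1024, 1024))
  # Only this value is returned, so only it comes back as an array.
  used = jax.random.normal(key_used, (4, 3))
  return used


lora_like = init_used_only(jax.random.key(0))
print(lora_like.shape)
```

In partial_init above, the kernel and bias that LoRALinear would create play the role of `unused`, since nnx.update overwrites them with the values from old_state.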