模型手术#

通常,Flax 模块和优化器会为您跟踪和更新参数。但是,有时您可能希望进行一些模型手术并自行调整参数张量。本指南向您展示如何完成此操作。

设置#

!pip install --upgrade -q pip jax jaxlib flax
import functools

import jax
import jax.numpy as jnp
from flax import traverse_util
from flax import linen as nn
from flax.core import freeze
import jax
import optax

使用 Flax 模块进行手术#

让我们为我们的演示创建一个小型卷积神经网络模型。

像往常一样,您可以运行 CNN.init(...)['params'] 以获取 params 并将其传递到训练的每个步骤中并进行修改。

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
      x = nn.Conv(features=32, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = nn.Conv(features=64, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = x.reshape((x.shape[0], -1))
      x = nn.Dense(features=256)(x)
      x = nn.relu(x)
      x = nn.Dense(features=10)(x)
      x = nn.log_softmax(x)
      return x

def get_initial_params(key):
    init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)
    initial_params = CNN().init(key, init_shape)['params']
    return initial_params

key = jax.random.key(0)
params = get_initial_params(key)

jax.tree_util.tree_map(jnp.shape, params)

请注意,返回的 params 是一个 FrozenDict,其中包含一些作为内核和偏差的 JAX 数组。

FrozenDict 不过是一个只读字典,Flax 将其设为只读是因为 JAX 的函数特性:JAX 数组是不可变的,新的 params 需要替换旧的 params。使字典只读可确保在训练和更新期间不会意外发生字典的任何就地修改。

实际上在 Flax 模块外部修改参数的一种方法是显式地将其展平并创建一个可变字典。请注意,您可以使用分隔符 sep 来连接所有嵌套的键。如果未给出 sep,则键将是所有嵌套键的元组。

# Get a flattened key-value list.
flat_params = traverse_util.flatten_dict(params, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_params)

现在您可以对参数执行任何操作。完成后,将其展开并将其用于未来的训练。

# Somehow modify a layer
dense_kernel = flat_params['Dense_1/kernel']
flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)

# Unflatten.
unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')
# Refreeze.
unflat_params = freeze(unflat_params)
jax.tree_util.tree_map(jnp.shape, unflat_params)

使用优化器进行手术#

当使用 Optax 作为优化器时,opt_state 实际上是构成优化器的各个梯度变换状态的嵌套元组。这些状态包含镜像参数树的 pytree,并且可以以相同的方式进行修改:展平、修改、展开,然后重新创建一个镜像原始状态的新优化器状态。

tx = optax.adam(1.0)
opt_state = tx.init(params)

# The optimizer state is a tuple of gradient transformation states.
jax.tree_util.tree_map(jnp.shape, opt_state)

优化器状态内的 pytree 遵循与参数相同的结构,并且可以以完全相同的方式展平/修改。

flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')
flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_mu)

修改后,重新创建优化器状态。将其用于未来的训练。

opt_state = (
    opt_state[0]._replace(
        mu=traverse_util.unflatten_dict(flat_mu, sep='/'),
        nu=traverse_util.unflatten_dict(flat_nu, sep='/'),
    ),
) + opt_state[1:]
jax.tree_util.tree_map(jnp.shape, opt_state)