提取中间值#
本指南将向您展示如何从模块中提取中间值。让我们从这个使用 nn.compact
的简单 CNN 开始。
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
因为此模块使用 nn.compact
,所以我们无法直接访问中间值。有一些方法可以公开它们
将中间值存储在新变量集合中#
可以使用对 sow
的调用来增强 CNN,以存储如下所示的中间值
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
class SowCNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
self.sow('intermediates', 'features', x)
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
sow
在变量集合不可变时充当无操作。因此,它非常适合于调试和可选地跟踪中间值。'intermediates' 集合也由 capture_intermediates
API 使用(请参阅使用 capture_intermediates 部分)。
请注意,默认情况下,sow
在每次调用时都会追加值
这是必要的,因为一旦实例化,模块可以在其父模块中多次调用,我们希望捕获所有播种的值。
因此,您需要确保您**不要**将中间值反馈到
variables
中。否则,每次调用都会增加该元组的长度并触发重新编译。要覆盖默认的追加行为,请指定
init_fn
和reduce_fn
- 请参阅Module.sow()
。
class SowCNN2(nn.Module):
@nn.compact
def __call__(self, x):
mod = SowCNN(name='SowCNN')
return mod(x) + mod(x) # Calling same module instance twice.
@jax.jit
def init(key, x):
variables = SowCNN2().init(key, x)
# By default the 'intermediates' collection is not mutable during init.
# So variables will only contain 'params' here.
return variables
@jax.jit
def predict(variables, x):
# If mutable='intermediates' is not specified, then .sow() acts as a noop.
output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates')
features = mod_vars['intermediates']['SowCNN']['features']
return output, features
batch = jnp.ones((1,28,28,1))
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)
assert len(feats) == 2 # Tuple with two values since module was called twice.
将模块重构为子模块#
对于清楚地了解如何拆分子模块的情况,这是一种有用的模式。在 setup
中公开的任何子模块都可以直接使用。在极限情况下,您可以在 setup
中定义所有子模块,并完全避免使用 nn.compact
。
class RefactoredCNN(nn.Module):
def setup(self):
self.features = Features()
self.classifier = Classifier()
def __call__(self, x):
x = self.features(x)
x = self.classifier(x)
return x
class Features(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
return x
class Classifier(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
@jax.jit
def init(key, x):
variables = RefactoredCNN().init(key, x)
return variables['params']
@jax.jit
def features(params, x):
return RefactoredCNN().apply({"params": params}, x,
method=lambda module, x: module.features(x))
params = init(jax.random.key(0), batch)
features(params, batch)
使用 capture_intermediates
#
Linen 支持自动捕获来自子模块的中间返回值,无需任何代码更改。此模式应被视为捕获中间值的“重锤”方法。作为调试和检查工具,它非常有用,但使用本指南中描述的其他模式将使您能够更细粒度地控制要提取哪些中间值。
在下面的代码示例中,我们检查任何中间激活是否是非有限的(NaN 或无穷大)
@jax.jit
def init(key, x):
variables = CNN().init(key, x)
return variables
@jax.jit
def predict(variables, x):
y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
intermediates = state['intermediates']
fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
return y, fin
variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
默认情况下,仅收集 __call__
方法的中间值。或者,您可以根据 Module
实例和方法名称传递自定义过滤器函数。
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
请注意,capture_intermediates
仅适用于层。您可以使用 self.sow
手动存储非层中间值,但过滤器函数不会应用于它。
class Model(nn.Module):
@nn.compact
def __call__(self, x):
a = nn.Dense(4)(x) # Dense_0
b = nn.Dense(4)(x) # Dense_1
c = a + b # not a Flax layer, so won't be stored as an intermediate
d = nn.Dense(4)(c) # Dense_2
return d
@jax.jit
def init(key, x):
variables = Model().init(key, x)
return variables['params']
@jax.jit
def predict(params, x):
return Model().apply({"params": params}, x, capture_intermediates=True)
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model was not stored because it's not a Flax layer
class Model(nn.Module):
@nn.compact
def __call__(self, x):
a = nn.Dense(4)(x) # Dense_0
b = nn.Dense(4)(x) # Dense_1
c = a + b
self.sow('intermediates', 'c', c) # store intermediate c
d = nn.Dense(4)(c) # Dense_2
return d
@jax.jit
def init(key, x):
variables = Model().init(key, x)
return variables['params']
@jax.jit
def predict(params, x):
# filter specifically for only the Dense_0 and Dense_2 layer
filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'})
return Model().apply({"params": params}, x, capture_intermediates=filter_fn)
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model is stored and isn't filtered out by the filter function
为了将从 self.sow
提取的中间值与从 capture_intermediates
提取的中间值分开,我们可以定义一个单独的集合,例如 self.sow('sow_intermediates', 'c', c)
,或者在调用 .apply()
后手动过滤掉中间值。例如
flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/')
flattened_dict['c']
在效率方面,只要所有内容都经过 jit 处理,那么您最终不使用的任何中间值都应该由 XLA 优化掉。
使用 Sequential
#
您还可以使用 Sequential
组合器(这在更多有状态的方法中很常见)的简单实现来定义 CNN
。这可能对非常简单的模型有用,并为您提供任意模型手术。但它可能非常有限——如果您甚至想添加一个条件,则必须从 Sequential
中重构并更明确地构建您的模型。
class Sequential(nn.Module):
layers: Sequence[nn.Module]
def __call__(self, x):
for layer in self.layers:
x = layer(x)
return x
def SeqCNN():
return Sequential([
nn.Conv(features=32, kernel_size=(3, 3)),
nn.relu,
lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
nn.Conv(features=64, kernel_size=(3, 3)),
nn.relu,
lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
lambda x: x.reshape((x.shape[0], -1)), # flatten
nn.Dense(features=256),
nn.relu,
nn.Dense(features=10),
nn.log_softmax,
])
@jax.jit
def init(key, x):
variables = SeqCNN().init(key, x)
return variables['params']
@jax.jit
def features(params, x):
return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)
batch = jnp.ones((1,28,28,1))
params = init(jax.random.key(0), batch)
features(params, batch)
提取中间值的梯度#
出于调试目的,提取中间值的梯度可能很有用。这可以通过在所需值上使用Module.perturb()
方法来完成。
class Model(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.relu(nn.Dense(8)(x))
x = self.perturb('hidden', x)
x = nn.Dense(2)(x)
x = self.perturb('logits', x)
return x
perturb
默认情况下会将变量添加到 perturbations
集合中,它的行为类似于恒等函数,并且扰动的梯度与输入的梯度匹配。要获得扰动,只需初始化模型即可
x = jnp.empty((1, 4)) # random data
y = jnp.empty((1, 2)) # random data
model = Model()
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']
最后计算损失相对于扰动的梯度,这些梯度将与中间值的梯度匹配
def loss_fn(params, perturbations, x, y):
y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
return jnp.mean((y_pred - y) ** 2)
intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)