使用过滤器#
注意:此页面与新的 Flax NNX API 相关。
过滤器在 Flax NNX 中被广泛用作在诸如 nnx.split
、nnx.state
和许多 Flax NNX 转换之类的 API 中创建 State
组的方法。例如
from flax import nnx
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = nnx.BatchStat(True)
foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
batch_stats = State({
'b': VariableState(
type=BatchStat,
value=True
)
})
这里 nnx.Param
和 nnx.BatchStat
被用作过滤器,将模型分成两组:一组包含参数,另一组包含批次统计信息。然而,这引出了以下问题
什么是过滤器?
为什么诸如
Param
或BatchStat
之类的类型是过滤器?如何对
State
进行分组/过滤?
过滤器协议#
通常,过滤器是以下形式的谓词函数
(path: tuple[Key, ...], value: Any) -> bool
其中 Key
是可哈希和可比较的类型,path
是表示嵌套结构中值路径的 Key
元组,而 value
是路径上的值。如果该值应包含在该组中,则该函数返回 True
,否则返回 False
。
类型显然不是这种形式的函数,因此它们被视为过滤器的原因是,正如我们接下来将看到的,类型和一些其他字面量被转换为谓词。例如,Param
大致转换为如下的谓词
def is_param(path, value) -> bool:
return isinstance(value, nnx.Param) or (
hasattr(value, 'type') and issubclass(value.type, nnx.Param)
)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
此类函数匹配任何 Param
的实例或具有作为 Param
子类的 type
属性的任何值。Flax NNX 内部使用 OfType
,它为给定类型定义了这种形式的可调用对象
is_param = nnx.OfType(nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
过滤器 DSL#
为了避免用户创建这些函数,Flax NNX 公开了一个小型 DSL,它被形式化为 nnx.filterlib.Filter
类型,它允许用户传递类型、布尔值、省略号、元组/列表等,并在内部将它们转换为适当的谓词。
以下是 Flax NNX 中包含的所有可调用过滤器及其 DSL 字面量(如果可用)的列表
字面量 |
可调用对象 |
描述 |
---|---|---|
|
|
匹配所有值 |
|
|
不匹配任何值 |
|
|
匹配属于 |
|
匹配具有包含给定 |
|
|
|
匹配具有等于 |
|
|
匹配与任何内部 |
|
匹配与所有内部 |
|
|
匹配与内部 |
让我们通过一个 nnx.vmap
示例来了解 DSL 的实际应用。假设我们想要将所有参数和 dropout
Rng(Keys|Counts) 在第 0 轴上矢量化,并将其余部分广播。为此,我们可以使用以下过滤器来定义一个 nnx.StateAxes
对象,该对象可以传递给 nnx.vmap
的 in_axes
,以指定如何对 model
的各种子状态进行矢量化
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})
@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
...
这里 (nnx.Param, 'dropout')
扩展为 Any(OfType(nnx.Param), WithTag('dropout'))
,而 ...
扩展为 Everything()
。
如果您希望手动将字面量转换为谓词,可以使用 nnx.filterlib.to_predicate
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))
print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))
分组状态#
掌握了过滤器的知识后,让我们看看 nnx.split
的大致实现方式。关键思想
使用
nnx.graph.flatten
获取节点的GraphDef
和State
表示形式。将所有过滤器转换为谓词。
使用
State.flat_state
获取状态的扁平表示形式。遍历扁平状态中的所有
(path, value)
对,并根据谓词对其进行分组。使用
State.from_flat_state
将扁平状态转换为嵌套的State
。
from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]
def split(node, *filters):
graphdef, state = nnx.graph.flatten(node)
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
for path, value in state.flat_state():
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
break
else:
raise ValueError(f'No filter matched {path = } {value = }')
states: tuple[nnx.GraphState, ...] = tuple(
nnx.State.from_flat_path(flat_state) for flat_state in flat_states
)
return graphdef, *states
# lets test it...
foo = Foo()
graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
batch_stats = State({
'b': VariableState(
type=BatchStat,
value=True
)
})
需要注意的一件非常重要的事情是,过滤是顺序相关的。第一个匹配一个值的过滤器会保留它,因此您应该将更具体的过滤器放在更通用的过滤器之前。例如,如果我们创建一个 SpecialParam
类型,它是 Param
的子类,以及一个包含两种参数类型的 Bar
对象,如果我们尝试在 SpecialParam
之前拆分 Param
,则所有值都将放置在 Param
组中,而 SpecialParam
组将为空,因为所有 SpecialParam
也是 Param
class SpecialParam(nnx.Param):
pass
class Bar(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = SpecialParam(0)
bar = Bar()
graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': VariableState(
type=Param,
value=0
),
'b': VariableState(
type=SpecialParam,
value=0
)
})
special_params = State({})
反转顺序将确保首先捕获 SpecialParam
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
special_params = State({
'b': VariableState(
type=SpecialParam,
value=0
)
})