使用过滤器

使用过滤器#

注意:此页面与新的 Flax NNX API 相关。

过滤器在 Flax NNX 中被广泛用作在诸如 nnx.splitnnx.state 和许多 Flax NNX 转换之类的 API 中创建 State 组的方法。例如

from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()

graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})

这里 nnx.Paramnnx.BatchStat 被用作过滤器,将模型分成两组:一组包含参数,另一组包含批次统计信息。然而,这引出了以下问题

  • 什么是过滤器?

  • 为什么诸如 ParamBatchStat 之类的类型是过滤器?

  • 如何对 State 进行分组/过滤?

过滤器协议#

通常,过滤器是以下形式的谓词函数


(path: tuple[Key, ...], value: Any) -> bool

其中 Key 是可哈希和可比较的类型,path 是表示嵌套结构中值路径的 Key 元组,而 value 是路径上的值。如果该值应包含在该组中,则该函数返回 True,否则返回 False

类型显然不是这种形式的函数,因此它们被视为过滤器的原因是,正如我们接下来将看到的,类型和一些其他字面量被转换为谓词。例如,Param 大致转换为如下的谓词

def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param) or (
    hasattr(value, 'type') and issubclass(value.type, nnx.Param)
  )

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True

此类函数匹配任何 Param 的实例或具有作为 Param 子类的 type 属性的任何值。Flax NNX 内部使用 OfType,它为给定类型定义了这种形式的可调用对象

is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True

过滤器 DSL#

为了避免用户创建这些函数,Flax NNX 公开了一个小型 DSL,它被形式化为 nnx.filterlib.Filter 类型,它允许用户传递类型、布尔值、省略号、元组/列表等,并在内部将它们转换为适当的谓词。

以下是 Flax NNX 中包含的所有可调用过滤器及其 DSL 字面量(如果可用)的列表

字面量

可调用对象

描述

...True

Everything()

匹配所有值

NoneFalse

Nothing()

不匹配任何值

type

OfType(type)

匹配属于 type 实例或具有作为 type 实例的 type 属性的值

PathContains(key)

匹配具有包含给定 key 的关联 path 的值

'{filter}' str

WithTag('{filter}')

匹配具有等于 '{filter}' 的字符串 tag 属性的值。由 RngKeyRngCount 使用。

(*filters) tuple[*filters] list

Any(*filters)

匹配与任何内部 filters 匹配的值

All(*filters)

匹配与所有内部 filters 匹配的值

Not(filter)

匹配与内部 filter 不匹配的值

让我们通过一个 nnx.vmap 示例来了解 DSL 的实际应用。假设我们想要将所有参数和 dropout Rng(Keys|Counts) 在第 0 轴上矢量化,并将其余部分广播。为此,我们可以使用以下过滤器来定义一个 nnx.StateAxes 对象,该对象可以传递给 nnx.vmapin_axes,以指定如何对 model 的各种子状态进行矢量化

state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})

@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
  ...

这里 (nnx.Param, 'dropout') 扩展为 Any(OfType(nnx.Param), WithTag('dropout')),而 ... 扩展为 Everything()

如果您希望手动将字面量转换为谓词,可以使用 nnx.filterlib.to_predicate

is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))

print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))

分组状态#

掌握了过滤器的知识后,让我们看看 nnx.split 的大致实现方式。关键思想

  • 使用 nnx.graph.flatten 获取节点的 GraphDefState 表示形式。

  • 将所有过滤器转换为谓词。

  • 使用 State.flat_state 获取状态的扁平表示形式。

  • 遍历扁平状态中的所有 (path, value) 对,并根据谓词对其进行分组。

  • 使用 State.from_flat_state 将扁平状态转换为嵌套的 State

from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  graphdef, state = nnx.graph.flatten(node)
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]

  for path, value in state.flat_state():
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        flat_states[i][path] = value
        break
    else:
      raise ValueError(f'No filter matched {path = } {value = }')

  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states

# lets test it...
foo = Foo()

graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})

需要注意的一件非常重要的事情是,过滤是顺序相关的。第一个匹配一个值的过滤器会保留它,因此您应该将更具体的过滤器放在更通用的过滤器之前。例如,如果我们创建一个 SpecialParam 类型,它是 Param 的子类,以及一个包含两种参数类型的 Bar 对象,如果我们尝试在 SpecialParam 之前拆分 Param,则所有值都将放置在 Param 组中,而 SpecialParam 组将为空,因为所有 SpecialParam 也是 Param

class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()

graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  ),
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
special_params = State({})

反转顺序将确保首先捕获 SpecialParam

graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
special_params = State({
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})