flax.traverse_util 包#

用于遍历不可变数据结构的工具。

Traversal 可以用来迭代和更新复杂的数据结构。Traversal 接受一个对象并返回其内容的子集。例如,Traversal 可以选择对象的属性

>>> from flax import traverse_util
>>> import dataclasses

>>> @dataclasses.dataclass
... class Foo:
...   foo: int = 0
...   bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]

可以使用组合来构建更复杂的遍历。从身份遍历开始并使用方法链来构建所需的 Traversal 通常很有用。

>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]

Traversal 也可以用来使用 update 方法进行更改

>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}

Traversal 从不改变原始数据。因此,更新本质上会返回包含提供更新的数据的副本。

遍历对象#

class flax.traverse_util.Traversal(*args, **kwargs)[source]#

所有遍历的基类。

compose(other)[source]#

组合两个遍历。

each()[source]#

遍历选定容器中的每个项目。

filter(fn)[source]#

过滤选定的值。

abstract iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

merge(*traversals)[source]#

组合任意数量的遍历并合并结果。

set(values, inputs)[source]#

覆盖 Traversal 选择的值。

参数
  • values – 包含新值的列表。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

tree()[source]#

遍历 pytree 中的每个项目。

abstract update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseId(*args, **kwargs)[source]#

身份 Traversal。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseMerge(*args, **kwargs)[source]#

合并一组遍历中的选择。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseCompose(*args, **kwargs)[source]#

组合两个遍历。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseFilter(*args, **kwargs)[source]#

根据谓词过滤选定的值。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseAttr(*args, **kwargs)[source]#

遍历对象的属性。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseItem(*args, **kwargs)[source]#

遍历对象的项目。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseEach(*args, **kwargs)[source]#

遍历容器中的每个项目。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

class flax.traverse_util.TraverseTree(*args, **kwargs)[source]#

遍历pytree中的每个项目。

iterate(inputs)[source]#

迭代此 Traversal 选择的值。

参数

inputs – 应该被遍历的对象。

返回值

遍历值的迭代器。

update(fn, inputs)[source]#

更新聚焦的项目。

参数
  • fn – 将每个遍历的项目映射到其更新值的回调函数。

  • inputs – 应该被遍历的对象。

返回值

具有更新值的新对象。

字典工具#

flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)[source]#

扁平化嵌套字典。

嵌套键被扁平化为一个元组。有关如何恢复嵌套字典结构,请参见 unflatten_dict

示例

>>> from flax.traverse_util import flatten_dict

>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_dict(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}

请注意,空字典将被忽略,并且不会由 unflatten_dict 恢复。

参数
  • xs – 嵌套字典

  • keep_empty_nodes – 用 traverse_util.empty_node 替换空字典。

  • is_leaf – 一个可选函数,它接受下一个嵌套字典和嵌套键,如果嵌套字典是叶子(即,不应该被进一步扁平化),则返回 True。

  • sep – 如果指定,则返回字典的键将是 sep 连接的字符串(如果为 None,则键将是元组)。

返回值

扁平化的字典。

flax.traverse_util.unflatten_dict(xs, sep=None)[source]#

反扁平化字典。

参见 flatten_dict

示例

>>> flat_xs = {
...   ('foo',): 1,
...   ('bar', 'a'): 2,
... }
>>> xs = unflatten_dict(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}
参数
  • xs – 扁平化的字典

  • sep – 分隔符(与 flatten_dict() 中使用的相同)。

返回值

嵌套字典。

flax.traverse_util.path_aware_map(f, nested_dict)[source]#

一个映射函数,它在嵌套字典结构上运行,同时考虑每个叶子的路径。

示例

>>> import jax.numpy as jnp
>>> from flax import traverse_util

>>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
>>> f = lambda path, x: x + 5 if 'x' in path else -x
>>> traverse_util.path_aware_map(f, params)
{'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
参数
  • f – 一个可调用对象,它接受 (path, value) 参数并将它们映射到一个新值。这里 path 是字符串元组。

  • nested_dict – 嵌套字典结构。

返回值

具有映射值的新的嵌套字典结构。

模型参数遍历#

class flax.traverse_util.ModelParamTraversal(*args, **kwargs)[source]#

使用名称过滤器选择模型参数。

此遍历在参数的嵌套字典上运行,并根据 filter_fn 参数选择子集。

参见 flax.optim.MultiOptimizer,了解如何使用 ModelParamTraversal 来使用特定优化器更新参数树的子集。

__init__(filter_fn)[source]#

构造一个新的 ModelParamTraversal。

参数

filter_fn – 一个函数,它接受参数的完整名称及其值,并返回是否应该选择此参数。参数的名称由模块层次结构和参数名称决定(例如:‘/module/sub_module/parameter_name’)。