flax.traverse_util 包#
用于遍历不可变数据结构的工具。
Traversal 可以用来迭代和更新复杂的数据结构。Traversal 接受一个对象并返回其内容的子集。例如,Traversal 可以选择对象的属性
>>> from flax import traverse_util
>>> import dataclasses
>>> @dataclasses.dataclass
... class Foo:
... foo: int = 0
... bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]
可以使用组合来构建更复杂的遍历。从身份遍历开始并使用方法链来构建所需的 Traversal 通常很有用。
>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
Traversal 也可以用来使用 update
方法进行更改
>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}
Traversal 从不改变原始数据。因此,更新本质上会返回包含提供更新的数据的副本。
遍历对象#
- class flax.traverse_util.Traversal(*args, **kwargs)[source]#
所有遍历的基类。
字典工具#
- flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)[source]#
扁平化嵌套字典。
嵌套键被扁平化为一个元组。有关如何恢复嵌套字典结构,请参见
unflatten_dict
。示例
>>> from flax.traverse_util import flatten_dict >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = flatten_dict(xs) >>> flat_xs {('foo',): 1, ('bar', 'a'): 2}
请注意,空字典将被忽略,并且不会由
unflatten_dict
恢复。- 参数
xs – 嵌套字典
keep_empty_nodes – 用
traverse_util.empty_node
替换空字典。is_leaf – 一个可选函数,它接受下一个嵌套字典和嵌套键,如果嵌套字典是叶子(即,不应该被进一步扁平化),则返回 True。
sep – 如果指定,则返回字典的键将是
sep
连接的字符串(如果为None
,则键将是元组)。
- 返回值
扁平化的字典。
- flax.traverse_util.unflatten_dict(xs, sep=None)[source]#
反扁平化字典。
参见
flatten_dict
示例
>>> flat_xs = { ... ('foo',): 1, ... ('bar', 'a'): 2, ... } >>> xs = unflatten_dict(flat_xs) >>> xs {'foo': 1, 'bar': {'a': 2}}
- 参数
xs – 扁平化的字典
sep – 分隔符(与
flatten_dict()
中使用的相同)。
- 返回值
嵌套字典。
- flax.traverse_util.path_aware_map(f, nested_dict)[source]#
一个映射函数,它在嵌套字典结构上运行,同时考虑每个叶子的路径。
示例
>>> import jax.numpy as jnp >>> from flax import traverse_util >>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}} >>> f = lambda path, x: x + 5 if 'x' in path else -x >>> traverse_util.path_aware_map(f, params) {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
- 参数
f – 一个可调用对象,它接受
(path, value)
参数并将它们映射到一个新值。这里path
是字符串元组。nested_dict – 嵌套字典结构。
- 返回值
具有映射值的新的嵌套字典结构。
模型参数遍历#
- class flax.traverse_util.ModelParamTraversal(*args, **kwargs)[source]#
使用名称过滤器选择模型参数。
此遍历在参数的嵌套字典上运行,并根据
filter_fn
参数选择子集。参见
flax.optim.MultiOptimizer
,了解如何使用ModelParamTraversal
来使用特定优化器更新参数树的子集。