flax.cursor 包#

Cursor API 允许对 pytree 进行可变操作。与进行许多嵌套的 dataclasses.replace 调用相比,此 API 为对深度嵌套的不可变数据结构进行部分更新提供了一种更符合人体工程学的设计解决方案。

为了说明,请考虑以下示例

>>> from flax.cursor import cursor
>>> import dataclasses
>>> from typing import Any

>>> @dataclasses.dataclass(frozen=True)
>>> class A:
...   x: Any

>>> a = A(A(A(A(A(A(A(0)))))))

要使用 dataclasses.replace 替换 int 0,我们必须编写许多嵌套调用

>>> a2 = dataclasses.replace(
...   a,
...   x=dataclasses.replace(
...     a.x,
...     x=dataclasses.replace(
...       a.x.x,
...       x=dataclasses.replace(
...         a.x.x.x,
...         x=dataclasses.replace(
...           a.x.x.x.x,
...           x=dataclasses.replace(
...             a.x.x.x.x.x,
...             x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
...           ),
...         ),
...       ),
...     ),
...   ),
... )

使用 Cursor API 可以更轻松地实现等效功能

>>> a3 = cursor(a).x.x.x.x.x.x.x.set(1)
>>> assert a2 == a3

Cursor 对象跟踪对它所做的更改,当调用 .build 时,会生成一个包含累积更改的新对象。基本用法包括将对象包装在 Cursor 中,对 Cursor 对象进行更改并生成包含累积更改的原始对象的副本。

flax.cursor.cursor(obj)[source]#

Cursor 包装到 obj 上并返回它。然后可以通过以下方式将更改应用于 Cursor 对象

  • 通过 .set 方法进行单行更改

  • 进行多次更改,然后调用 .build 方法

  • 通过 .apply_update 方法根据 pytree 路径和节点值进行多次更改,然后调用 .build 方法

.set 示例

>>> from flax.cursor import cursor

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

.build 示例

>>> from flax.cursor import cursor

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

.apply_update 示例

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> def update_fn(path, value):
...   '''Replace params with empty dictionary.'''
...   if 'params' in path:
...     return {}
...   return value

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params={'a': 1, 'b': 2},
...     tx=optax.adam(1e-3),
... )
>>> c = cursor(state)
>>> state2 = c.apply_update(update_fn).build()
>>> assert state2.params == {}
>>> assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged

如果底层 objlisttuple,还可以迭代 Cursor 对象以获取子 Cursor

>>> from flax.cursor import cursor

>>> c = cursor(((1, 2), (3, 4)))
>>> for child_c in c:
...   child_c[1] *= -1
>>> assert c.build() == ((1, -2), (3, -4))

查看每个方法的文档字符串以查看更多使用示例。

参数

obj – 要用 Cursor 包装的对象

返回值

一个围绕 obj 包装的 Cursor 对象。

class flax.cursor.Cursor(obj, parent_key)[source]#
apply_update(update_fn)[source]#

遍历 Cursor 对象,并通过一个 update_fn 递归地记录条件更改。更改记录在 Cursor 对象的 ._changes 字典中。要生成一个包含累积更改的原始对象的副本,请在调用 .apply_update 后调用 .build 方法。

update_fn 的函数签名为 (str, Any) -> Any

  • 输入参数是当前密钥路径(以 '/' 分隔的字符串形式)和该当前密钥路径处的值

  • 输出是新值(要么由 update_fn 修改,要么如果条件不满足则与输入值相同)

注意

  • 如果 update_fn 返回一个修改后的值,此方法将不会进一步递归到该分支以记录更改。例如,如果我们打算用一个 int 替换指向字典的属性,我们不需要在字典中查找更多更改,因为字典将被替换。

  • 使用 is 运算符来确定返回值是否已修改(通过将其与输入值进行比较)。因此,如果 update_fn 修改了可变容器(例如列表、字典等)并返回了相同的容器,则 .apply_update 会将返回值视为未修改,因为它包含相同的 id。为避免这种情况,请返回修改后的值的副本。

  • .apply_update 不会调用 update_fn 到 pytree 最顶层的值(即根节点)。update_fn 将首先在根节点的子节点上调用,然后 pytree 遍历将从那里递归继续。

示例

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params']

>>> def update_fn(path, value):
...   '''Multiply all dense kernel params by 2 and add 1.
...   Subtract the Dense_1 bias param by 1.'''
...   if 'kernel' in path:
...     return value * 2 + 1
...   elif 'Dense_1' in path and 'bias' in path:
...     return value - 1
...   return value

>>> c = cursor(params)
>>> new_params = c.apply_update(update_fn).build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.key(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged
参数

update_fn – 将有条件地记录对 Cursor 对象的更改的函数

返回值

具有由 update_fn 指定的已记录条件更改的当前 Cursor 对象。要生成一个包含累积更改的原始对象的副本,请在调用 .apply_update 后调用 .build 方法。

build()[source]#

创建一个包含累积更改的原始对象的副本并返回它。此方法应在对 Cursor 对象进行更改后调用。

注意

新对象自下而上构建,更改将首先应用于叶子节点,然后是其父节点,一直到根节点。

示例

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> new_fn = lambda x: x + 1
>>> c = cursor(state)
>>> c.params['b'][1] = 10
>>> c.apply_fn = new_fn
>>> modified_state = c.build()
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
>>> assert modified_state.apply_fn == new_fn
返回值

包含累积更改的原始对象的副本。

find(cond_fn)[source]#

遍历 Cursor 对象,并返回满足 cond_fn 中条件的子 Cursor 对象。cond_fn 的函数签名为 (str, Any) -> bool

  • 输入参数是当前密钥路径(以 '/' 分隔的字符串形式)和该当前密钥路径处的值

  • 输出是一个布尔值,表示是否返回此路径处的子 Cursor 对象

如果在某个特定密钥路径处 cond_fn 的值为 True,则此方法将不会进一步递归到该分支;即,此方法将在特定密钥路径中找到并返回满足 cond_fn 条件的“最早”子节点。

注意

  • 如果在某个特定密钥路径处 cond_fn 的值为 True,则此方法将不会进一步递归到该分支;即,此方法将在特定密钥路径中找到并返回满足 cond_fn 条件的“最早”子节点。

  • .find 不会搜索 pytree 最顶层的值(即根节点)。cond_fn 将递归地进行评估,从根节点的子节点开始。

示例

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

>>> def cond_fn(path, value):
...   '''Find the second dense layer params.'''
...   return 'Dense_1' in path

>>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1)

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> c = cursor(params)
>>> c2 = c.find(cond_fn)
>>> c2['kernel'] += 2
>>> c2['bias'] += 2
>>> new_params = c.build()

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all()
...   else:
...     assert (new_params[layer]['kernel'] == params[layer]['kernel']).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged
参数

cond_fn – 将有条件地查找子 Cursor 对象的函数

返回值

一个满足 cond_fn 条件的子 Cursor 对象。

find_all(cond_fn)[source]#

遍历 Cursor 对象,并返回一个满足 cond_fn 中条件的子 Cursor 对象的生成器。cond_fn 的函数签名为 (str, Any) -> bool

  • 输入参数是当前密钥路径(以 '/' 分隔的字符串形式)和该当前密钥路径处的值

  • 输出是一个布尔值,表示是否返回此路径处的子 Cursor 对象

注意

  • 如果在特定键路径下 cond_fn 的值为 True,则此方法不会进一步递归该分支;即此方法将在特定键路径中找到并返回满足 cond_fn 条件的“最早”子节点。

  • .find_all 不会搜索 pytree 最顶层的值(即根节点)。cond_fn 将从根节点的子节点开始递归评估。

示例

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

>>> def cond_fn(path, value):
...   '''Find all dense layer params.'''
...   return 'Dense' in path

>>> c = cursor(params)
>>> for dense_params in c.find_all(cond_fn):
...   dense_params['bias'] += 1
>>> new_params = c.build()

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged
参数

cond_fn – 将有条件地查找子 Cursor 对象的函数

返回值

一个包含满足 cond_fn 条件的子 Cursor 对象的生成器。

set(value)[source]#

为 Cursor 对象中的属性、特性、元素或条目设置一个新值,并返回包含新设置值的原始对象的副本。

示例

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> modified_state = cursor(state).params['b'][1].set(10)
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
参数

value – 用于在 Cursor 对象中设置属性、特性、元素或条目的值。

返回值

包含新设置值的原始对象的副本。