Module#
Flax 模块系统。
- class flax.linen.Module[source]#
所有神经网络模块的基类。
层和模型应该继承此类。
所有 Flax 模块都是 Python 3.7 数据类。由于数据类接管了
__init__
,因此您应该改写setup()
,它会自动被调用以初始化模块。模块可以包含子模块,并且可以通过这种方式嵌套在树状结构中。子模块可以作为
setup()
方法内的常规属性进行分配。您可以在模块子类上定义任意“前向传递”方法。虽然没有方法是特殊处理的,但
__call__
是一个常用的选择,因为它允许您将模块实例用作函数。>>> from flax import linen as nn >>> from typing import Tuple >>> class Module(nn.Module): ... features: Tuple[int, ...] = (16, 4) ... def setup(self): ... self.dense1 = nn.Dense(self.features[0]) ... self.dense2 = nn.Dense(self.features[1]) ... def __call__(self, x): ... return self.dense2(nn.relu(self.dense1(x)))
可选地,对于子模块定义与其使用方式位于同一位置的更简洁的模块实现,您可以使用
compact()
包装器。- __setattr__(name, val)[source]#
在该模块上设置一个属性。
我们重载 setattr 仅仅是为了通过在特殊的
setup()
函数中分配子模块来支持 Python 命名self.submodule_name = MyModule(...)
我们还支持列表和其他通用 PyTree,例如:
self.submodules = [MyModule0(..), MyModule1(..), ...]
- 参数
name – 要设置的属性。
val – 属性的值。
- apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)[source]#
将模块方法应用于变量,并返回输出和修改后的变量。
请注意,如果要对不同于
__call__
的类方法调用apply
,则应设置method
。例如,假设 Transformer 模块有一个名为encode
的方法,那么以下调用将在该方法上调用apply
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> import numpy as np >>> class Transformer(nn.Module): ... def encode(self, x): ... ... >>> x = jnp.ones((16, 9)) >>> model = Transformer() >>> variables = model.init(jax.random.key(0), x, method=Transformer.encode) >>> encoded = model.apply(variables, x, method=Transformer.encode)
如果提供函数实例,则使用未绑定的函数。例如,以下示例等效于上面的示例
>>> encoded = model.apply(variables, x, method=model.encode)
您还可以将字符串传递给模块的可调用属性。例如,前面的示例可以写成
>>> encoded = model.apply(variables, x, method='encode')
请注意
method
也可以是一个在Transformer
中未定义的函数。在这种情况下,该函数至少应有一个表示模块类的实例的参数>>> def other_fn(instance, x): ... # instance.some_module_attr(...) ... instance.encode ... ... >>> model.apply(variables, x, method=other_fn)
如果传递单个
PRNGKey
,Flax 将使用它来馈送'params'
RNG 流。如果您想使用其他 RNG 流或需要使用多个流,则可以将字典传递给apply
,该字典将每个 RNG 流名称映射到其对应的PRNGKey
。如果在用户未传递的 RNG 流名称上调用了self.make_rng(name)
,它将默认使用'params'
RNG 流。示例
>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, add_noise=False): ... x = nn.Dense(16)(x) ... x = nn.relu(x) ... ... if add_noise: ... # Add gaussian noise ... noise_key = self.make_rng('noise') ... x = x + jax.random.normal(noise_key, x.shape) ... ... return nn.Dense(1)(x) >>> x = jnp.empty((1, 7)) >>> module = Foo() >>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)} >>> variables = module.init(rngs, x) >>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> rngs['noise'] = jax.random.key(0) >>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> # different output (key(1) vs key(0)) >>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1) >>> del rngs['noise'] >>> # self.make_rng('noise') will default to using the 'params' RNG stream >>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> # same output (key(0)) >>> np.testing.assert_allclose(out1, out2) >>> # passing in a single key is equivalent to passing in {'params': key} >>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0)) >>> # same output (key(0)) >>> np.testing.assert_allclose(out2, out3)
- 参数
variables – 包含变量的字典,按变量集合进行键控。有关变量的更多详细信息,请参阅
flax.core.variables
。*args – 传递给指定应用方法的命名参数。
rngs – 一个 PRNGKey 字典,用于初始化 PRNG 序列。“params” PRNG 序列用于初始化参数。
method – 要在上面调用应用的函数。这通常是模块中的函数。如果提供,则应用此方法。如果没有提供,则应用模块的
__call__
方法。还可以提供字符串来按名称指定方法。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,只存储所有__call__
方法的返回值。可以传递函数来更改过滤器行为。过滤器函数接受模块实例和方法名称,并返回一个 bool 值,表示是否应存储该方法调用的输出。**kwargs – 传递给指定应用方法的关键字参数。
- 返回
如果
mutable
为 False,则返回输出。如果任何集合是可变的,则返回(output, vars)
,其中vars
是修改后的集合的字典。
- bind(variables, *args, rngs=None, mutable=False)[source]#
通过绑定变量和 RNG 来创建交互式模块实例。
bind
直接提供模块的“交互式”实例,而无需使用apply
来转换函数。这对于调试和笔记本等交互式用例特别有用,在这些用例中,函数会限制将代码分成不同单元格的能力。一旦变量(以及可选的 RNG)绑定到
Module
,它就会变成一个有状态的对象。请注意,习惯性的 JAX 是函数式的,因此交互式实例与普通 JAX API 不太协调。bind()
仅应用于交互式实验,在所有其他情况下,我们强烈建议用户使用apply()
代替。示例
>>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn >>> class AutoEncoder(nn.Module): ... def setup(self): ... self.encoder = nn.Dense(3) ... self.decoder = nn.Dense(5) ... ... def __call__(self, x): ... return self.decoder(self.encoder(x)) >>> x = jnp.ones((16, 9)) >>> ae = AutoEncoder() >>> variables = ae.init(jax.random.key(0), x) >>> model = ae.bind(variables) >>> z = model.encoder(x) >>> x_reconstructed = model.decoder(z)
- 参数
variables – 包含变量的字典,按变量集合进行键控。有关变量的更多详细信息,请参阅
flax.core.variables
。*args – 命名参数(未使用)。
rngs – 一个 PRNGKey 字典,用于初始化 PRNG 序列。
mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。
- 返回
具有绑定变量和 RNG 的该实例的副本。
- copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)[source]#
创建此模块的副本,可以选择更新参数。
- 参数
parent – 副本的父级。默认情况下,如果未明确指定,则当前模块将作为父级。
name – 复制模块的新名称,默认情况下将提供一个新的自动名称。
**updates – 属性更新。
- 返回
具有更新的名称、父级和属性的此模块的副本。
- get_variable(col, name, default=None)[source]#
检索变量的值。
- 参数
col – 变量集合。
name – 变量的名称。
default – 如果变量不存在于此范围内,则要返回的默认值。
- 返回
输入变量的值,如果变量不存在于此范围内,则为默认值。
- has_variable(col, name)[source]#
检查此模块中是否存在给定集合和名称的变量。
有关变量和集合的更多说明,请参见
flax.core.variables
。- 参数
col – 变量集合名称。
name – 变量的名称。
- 返回
如果变量存在,则为 True。
- init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#
使用变量初始化模块方法并返回修改后的变量。
init
的第一个参数是单个PRNGKey
,或者是一个字典,它将变量集合名称映射到它们的PRNGKeys
,并将调用method
(默认情况下是模块的__call__
函数),传递*args
和**kwargs
,并返回一个初始化变量的字典。示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> import numpy as np >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(16)(x) ... x = nn.BatchNorm(use_running_average=not train)(x) ... x = nn.relu(x) ... return nn.Dense(1)(x) >>> x = jnp.empty((1, 7)) >>> module = Foo() >>> key = jax.random.key(0) >>> variables = module.init(key, x, train=False)
如果传递单个
PRNGKey
,Flax 将使用它来馈送'params'
RNG 流。如果您想使用不同的 RNG 流或需要使用多个流,可以将一个字典传递给init
,它将每个 RNG 流名称映射到其对应的PRNGKey
。如果对用户未传递的 RNG 流名称调用self.make_rng(name)
,它将默认使用'params'
RNG 流。示例
>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(16)(x) ... x = nn.relu(x) ... ... other_variable = self.variable( ... 'other_collection', ... 'other_variable', ... lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape), ... x, ... ) ... x = x + other_variable.value ... ... return nn.Dense(1)(x) >>> module = Foo() >>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)} >>> variables0 = module.init(rngs, x) >>> rngs['other_rng'] = jax.random.key(0) >>> variables1 = module.init(rngs, x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables0['params'], variables1['params'] ... ) >>> # different other_variable (key(1) vs key(0)) >>> np.testing.assert_raises( ... AssertionError, ... np.testing.assert_allclose, ... variables0['other_collection']['other_variable'], ... variables1['other_collection']['other_variable'], ... ) >>> del rngs['other_rng'] >>> # self.make_rng('other_rng') will default to using the 'params' RNG stream >>> variables2 = module.init(rngs, x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables1['params'], variables2['params'] ... ) >>> # equivalent other_variable (key(0)) >>> np.testing.assert_allclose( ... variables1['other_collection']['other_variable'], ... variables2['other_collection']['other_variable'], ... ) >>> # passing in a single key is equivalent to passing in {'params': key} >>> variables3 = module.init(jax.random.key(0), x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables2['params'], variables3['params'] ... ) >>> # equivalent other_variable (key(0)) >>> np.testing.assert_allclose( ... variables2['other_collection']['other_variable'], ... variables3['other_collection']['other_variable'], ... )
Jitting
init
使用提供的参数的形状以延迟方式初始化模型,并避免使用实际值计算前向传递。示例>>> module = nn.Dense(1) >>> init_jit = jax.jit(module.init) >>> variables = init_jit(jax.random.key(0), x)
init
是apply
的一个轻量级包装器,因此其他apply
参数(如method
、mutable
和capture_intermediates
)也可用。- 参数
rngs – 变量集合的 rng。
*args – 传递给 init 函数的命名参数。
method – 可选方法。如果提供,则应用此方法。如果未提供,则应用
__call__
方法。也可以提供字符串来按名称指定方法。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为
True
,则在“intermediates”集合中捕获所有模块内部的中间返回值。默认情况下,只存储所有__call__
方法的返回值。可以传递一个函数来更改过滤行为。过滤函数接受模块实例和方法名称,并返回一个 bool 值,指示该方法调用输出是否应存储。**kwargs – 传递给 init 函数的关键字参数。
- 返回
初始化的变量字典。
- init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#
使用变量初始化模块方法并返回输出和修改后的变量。
- 参数
rngs – 变量集合的 rng。
*args – 传递给 init 函数的命名参数。
method – 可选方法。如果提供,则应用此方法。如果未提供,则应用
__call__
方法。也可以提供字符串来按名称指定方法。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为
True
,则在“intermediates”集合中捕获所有模块内部的中间返回值。默认情况下,只存储所有__call__
方法的返回值。可以传递一个函数来更改过滤行为。过滤函数接受模块实例和方法名称,并返回一个 bool 值,指示该方法调用输出是否应存储。**kwargs – 传递给 init 函数的关键字参数。
- 返回
(output, vars)
,其中vars
是修改后的集合的字典。
- is_initializing()[source]#
如果在 self.init(…) 或 nn.init(…)() 下运行,则返回 True。
这是一个帮助方法,用于处理简单初始化的常见情况,在这种情况下,我们希望在仅在
module.init
或nn.init
下调用时发生设置逻辑。对于更复杂的多分阶段初始化场景,最好测试特定变量集合的可变性,或者测试可能需要初始化的特定变量的存在。
- lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)[source]#
初始化模块,无需在实际输入上进行计算。
lazy_init 将初始化变量,而不会进行不必要的计算。输入数据应作为
jax.ShapeDtypeStruct
传递,它指定输入的形状和 dtype,但不提供任何具体数据。示例
>>> model = nn.Dense(features=256) >>> variables = model.lazy_init( ... jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))
传递给
lazy_init
的 args 和 kwargs args 可以是具体值(jax 数组、标量、布尔值)和抽象值(ShapeDtypeStruct)的混合。具体值仅对于影响变量初始化的参数是必需的。例如,模型可能需要一个关键字参数来启用/禁用模型的一部分。在这种情况下,应传递一个显式值(True/Flase),否则lazy_init
无法推断应初始化哪些变量。- 参数
rngs – 变量集合的 rng。
*args – 传递给 init 函数的参数。
method – 可选方法。如果提供,则应用此方法。如果未提供,则应用
__call__
方法。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。**kwargs – 传递给 init 函数的关键字参数。
- 返回
初始化的变量字典。
- make_rng(name='params')[source]#
从模块的给定 RNG 序列中返回一个新的 RNG 密钥。
新的 RNG 密钥是从上一个密钥中拆分的。因此,每次调用
make_rng
都会返回一个新的 RNG 密钥,同时仍然保证完全可重复性。注意
如果传递了一个无效的名称(即用户在
.init
或.apply
中没有为此名称传递 RNG 密钥),则name
将默认为'params'
。示例
>>> import jax >>> import flax.linen as nn >>> class ParamsModule(nn.Module): ... def __call__(self): ... return self.make_rng('params') >>> class OtherModule(nn.Module): ... def __call__(self): ... return self.make_rng('other') >>> key = jax.random.key(0) >>> params_out, _ = ParamsModule().init_with_output({'params': key}) >>> # self.make_rng('other') will default to using the 'params' RNG stream >>> other_out, _ = OtherModule().init_with_output({'params': key}) >>> assert params_out == other_out
阅读 Flax RNG 指南,了解有关 RNG 的更多信息:https://flax.org.cn/en/latest/guides/flax_fundamentals/rng_guide.html
- 参数
name – RNG 序列名称。
- 返回
新生成的 RNG 密钥。
- module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)[source]#
返回一个字典,将模块路径映射到模块实例。
此方法具有相同的签名,并在内部调用
Module.init
,但它不返回变量,而是返回一个字典,将模块路径映射到运行时使用的模块实例的无界副本。module_paths
使用jax.eval_shape
来运行前向计算,而不会消耗任何 FLOPs 或分配内存。示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> modules = Foo().module_paths(jax.random.key(0), x) >>> print({ ... p: type(m).__name__ for p, m in modules.items() ... }) {'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
- 参数
rngs – 变量集合的 rngs,如传递给
Module.init
。*args – 前向计算的参数。
show_repeated – 如果为
True
,则将显示对同一模块的重复调用,否则仅显示第一次调用。默认值为False
。mutable – 可以是布尔值、字符串或列表。指定哪些集合应该被视为可变的:
bool
:所有/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合的名称列表。默认情况下,除了 ‘intermediates’ 之外的所有集合都是可变的。**kwargs – 传递给前向计算的关键字参数。
- 返回
一个字典,将模块路径映射到模块实例。
- param(name, init_fn, *init_args, unbox=True, **init_kwargs)[source]#
声明并返回此模块中的参数。
参数是名为 “params” 的集合中的只读变量。有关变量的更多详细信息,请参阅
flax.core.variables
。init_fn
的第一个参数被假定为 PRNG 密钥,该密钥会自动提供,无需使用init_args
或init_kwargs
传递。>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(4)(x) ... mean = self.param('mean', nn.initializers.lecun_normal(), x.shape) ... ... ... return x * mean >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}
在上例中,函数
lecun_normal
期待两个参数:key
和shape
,但只需明确提供shape
;key
会使用用于params
的 PRNG 自动设置,该 PRNG 在使用init()
初始化模块时传递。- 参数
name – 参数名称。
init_fn – 将被调用以计算此变量初始值的函数。该函数将只在第一次在该模块中使用该参数时被调用。
*init_args – 传递给 init_fn 的位置参数。
unbox – 如果为 True,则
AxisMetadata
实例将被替换为其未装箱的值,请参阅flax.nn.meta.unbox
(默认值:True)。**init_kwargs – 传递给 init_fn 的关键字参数。
- 返回
已初始化参数的值。如果参数已存在,则会抛出错误。
- property path#
获取此模块的路径。顶级根模块具有空路径
()
。请注意,此方法只能用于具有有效范围的绑定模块。示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class SubModel(nn.Module): ... @nn.compact ... def __call__(self, x): ... print(f'SubModel path: {self.path}') ... return x >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... print(f'Model path: {self.path}') ... return SubModel()(x) >>> model = Model() >>> variables = model.init(jax.random.key(0), jnp.ones((1, 2))) Model path: () SubModel path: ('SubModel_0',)
- perturb(name, value, collection='perturbations')[source]#
向中间值添加一个零值变量(“扰动”)。
value
的梯度将与该扰动变量的梯度相同。因此,如果使用 params 和 perturbations 作为独立参数定义损失函数,则可以通过对 perturbation 参数运行jax.grad
来获取value
的中间梯度。注意
这是一个实验性 API,可能会在以后针对性能和可用性进行调整。在其当前阶段,它会创建额外的虚拟变量,这些变量会占用额外的内存空间。仅将其用于调试训练中的梯度。
示例
>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = self.perturb('dense3', x) ... return nn.Dense(2)(x) >>> def loss(variables, inputs, targets): ... preds = model.apply(variables, inputs) ... return jnp.square(preds - targets).mean() >>> x = jnp.ones((2, 9)) >>> y = jnp.ones((2, 2)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y) >>> print(intm_grads['perturbations']['dense3']) [[-1.456924 -0.44332537 0.02422847] [-1.456924 -0.44332537 0.02422847]]
如果未将 perturbations 传递给
apply
,则perturb
将表现为无操作,因此可以轻松地在不需要时禁用该行为>>> model.apply(variables, x) # works as expected Array([[-1.0980128 , -0.67961735], [-1.0980128 , -0.67961735]], dtype=float32) >>> model.apply({'params': variables['params']}, x) # behaves like a no-op Array([[-1.0980128 , -0.67961735], [-1.0980128 , -0.67961735]], dtype=float32) >>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y) >>> 'perturbations' not in intm_grads True
- put_variable(col, name, value)[source]#
如果给定变量是可变的,则更新其值,否则会抛出错误。
- 参数
col – 变量集合。
name – 变量的名称。
value – 变量的新值。
- setup()[source]#
延迟初始化模块(类似于延迟
__init__
)。setup
会在绑定模块时在模块实例上延迟调用一次,在调用任何其他方法(如__call__
)之前,或在访问setup
定义的self
上的属性之前。这种情况可能发生在三种情况下
在通过在另一个模块的
setup
方法中将其分配给另一个模块的属性来为模块命名后(请参阅__setattr__()
)。>>> class MyModule(nn.Module): ... def setup(self): ... submodule = nn.Conv(...) ... # Accessing `submodule` attributes does not yet work here. ... # The following line invokes `self.__setattr__`, which gives ... # `submodule` the name "conv1". ... self.conv1 = submodule ... # Accessing `submodule` attributes or methods is now safe and ... # either causes setup() to be called once.
在使用
compact()
包装的方法内构造模块后,在调用另一个方法或访问setup
定义的属性之前立即调用。
- sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#
将值存储在集合中。
集合可用于收集中间值,而无需在每个模块调用时明确传递容器。
如果目标集合不可变,则
sow
会表现为无操作,并返回False
。示例
>>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... self.sow('intermediates', 'h', h) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply(variables, x, mutable=['intermediates']) >>> jax.tree.map(jnp.shape, state['intermediates']) {'h': ((16, 4),)}
默认情况下,值将存储在一个元组中,每个存储的值都会附加到末尾。这样,当多次调用同一模块时,就可以跟踪所有中间值。或者,可以传递自定义的 init/reduce 函数。
>>> class Foo2(nn.Module): ... @nn.compact ... def __call__(self, x): ... init_fn = lambda: 0 ... reduce_fn = lambda a, b: a + b ... self.sow('intermediates', 'h', x, ... init_fn=init_fn, reduce_fn=reduce_fn) ... self.sow('intermediates', 'h', x * 2, ... init_fn=init_fn, reduce_fn=reduce_fn) ... return x >>> x = jnp.ones((1, 1)) >>> model = Foo2() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply( ... variables, x, mutable=['intermediates']) >>> print(state['intermediates']) {'h': Array([[3.]], dtype=float32)}
- 参数
col – 变量集合的名称。
name – 变量的名称。
value – 变量的值。
reduce_fn – 用于将现有值与新值结合的函数。默认情况下,将值附加到元组。
init_fn – 对于存储的第一个值,
reduce_fn
将会传递init_fn
的结果以及要存储的值。默认值为一个空元组。
- 返回
True
如果值已成功存储,否则为False
。
- tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#
创建以表格形式表示的模块摘要。
此方法具有与
Module.init
相同的签名,并且内部调用它,但是它不返回变量,而是返回以表格形式汇总模块的字符串。tabulate
使用jax.eval_shape
来运行前向计算,而不会消耗任何 FLOP 或分配内存。可以将其他参数传递到
console_kwargs
参数中,例如,{'width': 120}
。有关console_kwargs
参数的完整列表,请参见:https://rich.pythonlang.cn/en/stable/reference/console.html#rich.console.Console示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> # print(Foo().tabulate( >>> # jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))
这将产生以下输出
Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B)
**注意**:表格中的行顺序不代表执行顺序,而是与
variables
中按字母顺序排序的键的顺序一致。**注意**:如果模块不可微分,则
vjp_flops
将返回0
。- 参数
rngs – 变量集合的 rngs,如传递给
Module.init
。*args – 前向计算的参数。
**depth** – 控制摘要可以深入到多少个子模块。默认情况下,它是
None
,这意味着没有限制。如果子模块由于深度限制而未显示,则其参数计数和字节将添加到其第一个显示祖先的行,以使所有行的总和始终加起来等于模块的总参数数量。show_repeated – 如果为
True
,则将显示对同一模块的重复调用,否则仅显示第一次调用。默认值为False
。mutable – 可以是布尔值、字符串或列表。指定哪些集合应该被视为可变的:
bool
:所有/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合的名称列表。默认情况下,除了 ‘intermediates’ 之外的所有集合都是可变的。**console_kwargs** – 一个可选字典,包含传递给
rich.console.Console
的其他关键字参数,用于渲染表格。默认参数为{'force_terminal': True, 'force_jupyter': False}
。**table_kwargs** – 一个可选字典,包含传递给
rich.table.Table
构造函数的其他关键字参数。**column_kwargs** – 一个可选字典,包含传递给
rich.table.Table.add_column
的其他关键字参数,用于向表格中添加列。**compute_flops** – 是否在表格中包含一个
flops
列,列出每个模块前向传递的估计 FLOP 成本。确实会产生实际的设备上计算/编译/内存分配,但对于大型模块仍会引入开销(例如,Stable Diffusion 的 UNet 额外需要 20 秒,而其他情况下制表将在 5 秒内完成)。**compute_vjp_flops** – 是否在表格中包含一个
vjp_flops
列,列出每个模块反向传递的估计 FLOP 成本。引入的计算开销约为compute_flops
的 2-3 倍。**kwargs – 传递给前向计算的关键字参数。
- 返回
一个总结模块的字符串。
- unbind()[source]#
返回一个模块及其变量的未绑定副本。
unbind
有助于创建绑定模块的无状态版本。一个常见用例的示例:提取在
setup()
中定义的子模块及其相应的变量:1)临时bind
父模块;然后 2)unbind
所需的子模块。(请记住,setup()
仅在绑定模块时才会被调用。)>>> class Encoder(nn.Module): ... @nn.compact ... def __call__(self, x): ... ... ... return nn.Dense(256)(x) >>> class Decoder(nn.Module): ... @nn.compact ... def __call__(self, x): ... ... ... return nn.Dense(784)(x) >>> class AutoEncoder(nn.Module): ... def setup(self): ... self.encoder = Encoder() ... self.decoder = Decoder() ... ... def __call__(self, x): ... return self.decoder(self.encoder(x)) >>> module = AutoEncoder() >>> variables = module.init(jax.random.key(0), jnp.ones((1, 784))) >>> # Extract the Encoder sub-Module and its variables >>> encoder, encoder_vars = module.bind(variables).encoder.unbind()
- 返回
一个元组,其中包含该模块及其变量的未绑定副本。
- variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)[source]#
在此模块中声明并返回一个变量。
有关更多信息,请参见
flax.core.variables
。另请参见param()
,了解在“params”集合中定义只读变量的简写方式。与
param()
不同,使用init_fn
传递的所有参数都应明确传递。>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(4)(x) ... key = self.make_rng('stats') ... mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape) ... ... ... return x * mean.value >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}
在上面的示例中,函数
lecun_normal
预期两个参数:key
和shape
,并且两者都必须传递。当调用init()
和apply()
时,必须显式提供stats
的 PRNG。- 参数
col – 变量集合名称。
**name** – 变量名称。
**init_fn** – 将被调用以计算此变量初始值的函数。此函数仅会在此模块中第一次使用此变量时被调用。如果为 None,则变量必须已初始化,否则将引发错误。
*init_args – 传递给 init_fn 的位置参数。
unbox – 如果为 True,则
AxisMetadata
实例将被替换为其未装箱的值,请参阅flax.nn.meta.unbox
(默认值:True)。**init_kwargs** – 要传递给 init_fn 的关键字参数。
- 返回
一个
flax.core.variables.Variable
,可以通过“”.value”属性进行读取或设置。如果变量已存在,则会抛出错误。
- property variables#
返回此模块中的变量。
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#
创建一个应用函数,以将绑定模块与
fn
一起调用。与
Module.apply
不同,此函数返回一个新函数,其签名为(variables, *args, rngs=None, **kwargs) -> T
,其中T
是fn
的返回类型。如果mutable
不是False
,则返回类型是一个元组,其中第二个项目是包含已变异变量的FrozenDict
。返回的应用函数可以直接与 JAX 变换组合,例如
jax.jit
>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3)))
- 参数
**fn** – 应应用的函数。传递的第一个参数将是一个模块实例,该实例是具有变量和 RNG 绑定的
module
。**module** – 将用于将变量和 RNG 绑定的
Module
。作为fn
的第一个参数传递的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。**capture_intermediates** – 如果为
True
,则捕获所有模块在“intermediates”集合中的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数获取模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的应用函数。
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
创建一个 init 函数,使用绑定模块调用
fn
。与
Module.init
不同,此函数返回一个具有签名(rngs, *args, **kwargs) -> variables
的新函数。 rngs 可以是 PRNGKeys 字典或单个`PRNGKey`
,这相当于传递一个带有名称为“params”的单个 PRNGKey 的字典。返回的 init 函数可以直接与 JAX 变换(如
jax.jit
)组合。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
**fn** – 应应用的函数。传递的第一个参数将是一个模块实例,该实例是具有变量和 RNG 绑定的
module
。**module** – 将用于将变量和 RNG 绑定的
Module
。作为fn
的第一个参数传递的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为 True,则将所有模块内部的中间返回值捕获到“intermediates”集合中。 默认情况下,仅存储所有 __call__ 方法的返回值。 可以传递一个函数来更改过滤器行为。 过滤器函数接受模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 init 函数。
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
创建一个 init 函数,使用绑定模块调用
fn
,并返回函数输出。与
Module.init_with_output
不同,此函数返回一个具有签名(rngs, *args, **kwargs) -> (T, variables)
的新函数,其中T
是fn
的返回类型。 rngs 可以是 PRNGKeys 字典或单个`PRNGKey`
,这相当于传递一个带有名称为“params”的单个 PRNGKey 的字典。返回的 init 函数可以直接与 JAX 变换(如
jax.jit
)组合。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
**fn** – 应应用的函数。传递的第一个参数将是一个模块实例,该实例是具有变量和 RNG 绑定的
module
。**module** – 将用于将变量和 RNG 绑定的
Module
。作为fn
的第一个参数传递的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:
bool
:所有集合/无集合都是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。**capture_intermediates** – 如果为
True
,则捕获所有模块在“intermediates”集合中的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数获取模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 init 函数。
- flax.linen.intercept_methods(interceptor)[source]#
注册新的方法拦截器。
方法拦截器允许您(从远处)拦截对模块的 方法调用。 它类似于装饰器。 您可以在调用底层方法之前修改 args/kwargs,或者修改从调用底层方法返回的结果。 或者您也可以完全跳过调用底层方法,并决定执行其他操作。 例如
>>> import flax.linen as nn >>> import jax.numpy as jnp ... >>> class Foo(nn.Module): ... def __call__(self, x): ... return x ... >>> def my_interceptor1(next_fun, args, kwargs, context): ... print('calling my_interceptor1') ... return next_fun(*args, **kwargs) ... >>> foo = Foo() >>> with nn.intercept_methods(my_interceptor1): ... _ = foo(jnp.ones([1])) calling my_interceptor1
您也可以在同一方法上注册多个拦截器。 拦截器将按顺序运行。 例如
>>> def my_interceptor2(next_fun, args, kwargs, context): ... print('calling my_interceptor2') ... return next_fun(*args, **kwargs) ... >>> with nn.intercept_methods(my_interceptor1), \ ... nn.intercept_methods(my_interceptor2): ... _ = foo(jnp.ones([1])) calling my_interceptor1 calling my_interceptor2
您可以通过直接调用
context.orig_method
来跳过其他拦截器。 例如>>> def my_interceptor3(next_fun, args, kwargs, context): ... print('calling my_interceptor3') ... return context.orig_method(*args, **kwargs) >>> with nn.intercept_methods(my_interceptor3), \ ... nn.intercept_methods(my_interceptor1), \ ... nn.intercept_methods(my_interceptor2): ... _ = foo(jnp.ones([1])) calling my_interceptor3
以下方法无法拦截
使用
nn.nowrap
装饰的方法。双下划线方法,包括
__eq__
、__repr__
、__init__
、__hash__
和__post_init__
。模块数据类字段。
模块描述符。
- 参数
interceptor – 方法拦截器。
修改其中一个模块,使它们共享相同的范围。 当您想要包装模块并扩展其功能而不改变参数结构时,这很有用。
share_scope
接受两个模块,module
和other
。 如果other
具有范围,并且其不是 ``module`` 范围的子级,则module
将使用other
的范围。>>> import flax.linen as nn >>> import jax >>> from jax import numpy as jnp, random ... >>> class DenseLoRA(nn.Module): ... base: nn.Dense ... rank: int ... ... def setup(self): ... nn.share_scope(self, self.base) ... ... @nn.compact ... def __call__(self, x: jax.Array): ... din, dout = x.shape[-1], self.base.features ... A = self.param('A', nn.zeros_init(), (din, self.rank)) ... B = self.param('B', nn.zeros_init(), (self.rank, dout)) ... return self.base(x) + x @ A @ B ... >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x: jax.Array): ... dense = nn.Dense(10) # base scope ... return DenseLoRA(dense, rank=2)(x) # reuse the base scope ... >>> model = Model() ... >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params'] >>> list(params['Dense_0'].keys()) ['A', 'B', 'kernel', 'bias']
当
other
的范围是module
范围的子级时,other
将改为使用module
的范围。>>> class DenseLoRA(nn.Module): ... features: int ... rank: int ... ... def setup(self): ... self.child = nn.Dense(self.features) ... nn.share_scope(self, self.child) ... ... @nn.compact ... def __call__(self, x: jax.Array): ... din, dout = x.shape[-1], self.features ... A = self.param('A', nn.zeros_init(), (din, self.rank)) ... B = self.param('B', nn.zeros_init(), (self.rank, dout)) ... return self.child(x) + x @ A @ B ... >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x: jax.Array): ... return DenseLoRA(10, rank=2)(x) ... >>> model = Model() ... >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params'] >>> list(params['DenseLoRA_0'].keys()) ['A', 'B', 'kernel', 'bias']