Module#

Flax 模块系统。

class flax.linen.Module[source]#

所有神经网络模块的基类。

层和模型应该继承此类。

所有 Flax 模块都是 Python 3.7 数据类。由于数据类接管了 __init__,因此您应该改写 setup(),它会自动被调用以初始化模块。

模块可以包含子模块,并且可以通过这种方式嵌套在树状结构中。子模块可以作为 setup() 方法内的常规属性进行分配。

您可以在模块子类上定义任意“前向传递”方法。虽然没有方法是特殊处理的,但 __call__ 是一个常用的选择,因为它允许您将模块实例用作函数。

>>> from flax import linen as nn
>>> from typing import Tuple

>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)

...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])

...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))

可选地,对于子模块定义与其使用方式位于同一位置的更简洁的模块实现,您可以使用 compact() 包装器。

__setattr__(name, val)[source]#

在该模块上设置一个属性。

我们重载 setattr 仅仅是为了通过在特殊的 setup() 函数中分配子模块来支持 Python 命名

self.submodule_name = MyModule(...)

我们还支持列表和其他通用 PyTree,例如:

self.submodules = [MyModule0(..), MyModule1(..), ...]
参数
  • name – 要设置的属性。

  • val – 属性的值。

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)[source]#

将模块方法应用于变量,并返回输出和修改后的变量。

请注意,如果要对不同于 __call__ 的类方法调用 apply,则应设置 method。例如,假设 Transformer 模块有一个名为 encode 的方法,那么以下调用将在该方法上调用 apply

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)

如果提供函数实例,则使用未绑定的函数。例如,以下示例等效于上面的示例

>>> encoded = model.apply(variables, x, method=model.encode)

您还可以将字符串传递给模块的可调用属性。例如,前面的示例可以写成

>>> encoded = model.apply(variables, x, method='encode')

请注意 method 也可以是一个在 Transformer 中未定义的函数。在这种情况下,该函数至少应有一个表示模块类的实例的参数

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)

如果传递单个 PRNGKey,Flax 将使用它来馈送 'params' RNG 流。如果您想使用其他 RNG 流或需要使用多个流,则可以将字典传递给 apply,该字典将每个 RNG 流名称映射到其对应的 PRNGKey。如果在用户未传递的 RNG 流名称上调用了 self.make_rng(name),它将默认使用 'params' RNG 流。

示例

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
参数
  • variables – 包含变量的字典,按变量集合进行键控。有关变量的更多详细信息,请参阅 flax.core.variables

  • *args – 传递给指定应用方法的命名参数。

  • rngs – 一个 PRNGKey 字典,用于初始化 PRNG 序列。“params” PRNG 序列用于初始化参数。

  • method – 要在上面调用应用的函数。这通常是模块中的函数。如果提供,则应用此方法。如果没有提供,则应用模块的 __call__ 方法。还可以提供字符串来按名称指定方法。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。

  • capture_intermediates – 如果为 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,只存储所有 __call__ 方法的返回值。可以传递函数来更改过滤器行为。过滤器函数接受模块实例和方法名称,并返回一个 bool 值,表示是否应存储该方法调用的输出。

  • **kwargs – 传递给指定应用方法的关键字参数。

返回

如果 mutable 为 False,则返回输出。如果任何集合是可变的,则返回 (output, vars),其中 vars 是修改后的集合的字典。

bind(variables, *args, rngs=None, mutable=False)[source]#

通过绑定变量和 RNG 来创建交互式模块实例。

bind 直接提供模块的“交互式”实例,而无需使用 apply 来转换函数。这对于调试和笔记本等交互式用例特别有用,在这些用例中,函数会限制将代码分成不同单元格的能力。

一旦变量(以及可选的 RNG)绑定到 Module,它就会变成一个有状态的对象。请注意,习惯性的 JAX 是函数式的,因此交互式实例与普通 JAX API 不太协调。 bind() 仅应用于交互式实验,在所有其他情况下,我们强烈建议用户使用 apply() 代替。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = nn.Dense(3)
...     self.decoder = nn.Dense(5)
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> x = jnp.ones((16, 9))
>>> ae = AutoEncoder()
>>> variables = ae.init(jax.random.key(0), x)
>>> model = ae.bind(variables)
>>> z = model.encoder(x)
>>> x_reconstructed = model.decoder(z)
参数
  • variables – 包含变量的字典,按变量集合进行键控。有关变量的更多详细信息,请参阅 flax.core.variables

  • *args – 命名参数(未使用)。

  • rngs – 一个 PRNGKey 字典,用于初始化 PRNG 序列。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。

返回

具有绑定变量和 RNG 的该实例的副本。

copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)[source]#

创建此模块的副本,可以选择更新参数。

参数
  • parent – 副本的父级。默认情况下,如果未明确指定,则当前模块将作为父级。

  • name – 复制模块的新名称,默认情况下将提供一个新的自动名称。

  • **updates – 属性更新。

返回

具有更新的名称、父级和属性的此模块的副本。

get_variable(col, name, default=None)[source]#

检索变量的值。

参数
  • col – 变量集合。

  • name – 变量的名称。

  • default – 如果变量不存在于此范围内,则要返回的默认值。

返回

输入变量的值,如果变量不存在于此范围内,则为默认值。

has_rng(name)[source]#

如果存在名为 name 的 PRNGSequence,则返回 true。

has_variable(col, name)[source]#

检查此模块中是否存在给定集合和名称的变量。

有关变量和集合的更多说明,请参见 flax.core.variables

参数
  • col – 变量集合名称。

  • name – 变量的名称。

返回

如果变量存在,则为 True。

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#

使用变量初始化模块方法并返回修改后的变量。

init 的第一个参数是单个 PRNGKey,或者是一个字典,它将变量集合名称映射到它们的 PRNGKeys,并将调用 method(默认情况下是模块的 __call__ 函数),传递 *args**kwargs,并返回一个初始化变量的字典。

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)

如果传递单个 PRNGKey,Flax 将使用它来馈送 'params' RNG 流。如果您想使用不同的 RNG 流或需要使用多个流,可以将一个字典传递给 init,它将每个 RNG 流名称映射到其对应的 PRNGKey。如果对用户未传递的 RNG 流名称调用 self.make_rng(name),它将默认使用 'params' RNG 流。

示例

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )

Jitting init 使用提供的参数的形状以延迟方式初始化模型,并避免使用实际值计算前向传递。示例

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)

initapply 的一个轻量级包装器,因此其他 apply 参数(如 methodmutablecapture_intermediates)也可用。

参数
  • rngs – 变量集合的 rng。

  • *args – 传递给 init 函数的命名参数。

  • method – 可选方法。如果提供,则应用此方法。如果未提供,则应用 __call__ 方法。也可以提供字符串来按名称指定方法。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果为 True,则在“intermediates”集合中捕获所有模块内部的中间返回值。默认情况下,只存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数接受模块实例和方法名称,并返回一个 bool 值,指示该方法调用输出是否应存储。

  • **kwargs – 传递给 init 函数的关键字参数。

返回

初始化的变量字典。

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#

使用变量初始化模块方法并返回输出和修改后的变量。

参数
  • rngs – 变量集合的 rng。

  • *args – 传递给 init 函数的命名参数。

  • method – 可选方法。如果提供,则应用此方法。如果未提供,则应用 __call__ 方法。也可以提供字符串来按名称指定方法。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果为 True,则在“intermediates”集合中捕获所有模块内部的中间返回值。默认情况下,只存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数接受模块实例和方法名称,并返回一个 bool 值,指示该方法调用输出是否应存储。

  • **kwargs – 传递给 init 函数的关键字参数。

返回

(output, vars),其中 vars 是修改后的集合的字典。

is_initializing()[source]#

如果在 self.init(…) 或 nn.init(…)() 下运行,则返回 True。

这是一个帮助方法,用于处理简单初始化的常见情况,在这种情况下,我们希望在仅在 module.initnn.init 下调用时发生设置逻辑。对于更复杂的多分阶段初始化场景,最好测试特定变量集合的可变性,或者测试可能需要初始化的特定变量的存在。

is_mutable_collection(col)[source]#

如果集合 col 是可变的,则返回 true。

lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)[source]#

初始化模块,无需在实际输入上进行计算。

lazy_init 将初始化变量,而不会进行不必要的计算。输入数据应作为 jax.ShapeDtypeStruct 传递,它指定输入的形状和 dtype,但不提供任何具体数据。

示例

>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

传递给 lazy_init 的 args 和 kwargs args 可以是具体值(jax 数组、标量、布尔值)和抽象值(ShapeDtypeStruct)的混合。具体值仅对于影响变量初始化的参数是必需的。例如,模型可能需要一个关键字参数来启用/禁用模型的一部分。在这种情况下,应传递一个显式值(True/Flase),否则 lazy_init 无法推断应初始化哪些变量。

参数
  • rngs – 变量集合的 rng。

  • *args – 传递给 init 函数的参数。

  • method – 可选方法。如果提供,则应用此方法。如果未提供,则应用 __call__ 方法。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。

  • **kwargs – 传递给 init 函数的关键字参数。

返回

初始化的变量字典。

make_rng(name='params')[source]#

从模块的给定 RNG 序列中返回一个新的 RNG 密钥。

新的 RNG 密钥是从上一个密钥中拆分的。因此,每次调用 make_rng 都会返回一个新的 RNG 密钥,同时仍然保证完全可重复性。

注意

如果传递了一个无效的名称(即用户在 .init.apply 中没有为此名称传递 RNG 密钥),则 name 将默认为 'params'

示例

>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out

阅读 Flax RNG 指南,了解有关 RNG 的更多信息:https://flax.org.cn/en/latest/guides/flax_fundamentals/rng_guide.html

参数

name – RNG 序列名称。

返回

新生成的 RNG 密钥。

module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)[source]#

返回一个字典,将模块路径映射到模块实例。

此方法具有相同的签名,并在内部调用 Module.init,但它不返回变量,而是返回一个字典,将模块路径映射到运行时使用的模块实例的无界副本。 module_paths 使用 jax.eval_shape 来运行前向计算,而不会消耗任何 FLOPs 或分配内存。

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
...     p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
参数
  • rngs – 变量集合的 rngs,如传递给 Module.init

  • *args – 前向计算的参数。

  • show_repeated – 如果为 True,则将显示对同一模块的重复调用,否则仅显示第一次调用。默认值为 False

  • mutable – 可以是布尔值、字符串或列表。指定哪些集合应该被视为可变的:bool:所有/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合的名称列表。默认情况下,除了 ‘intermediates’ 之外的所有集合都是可变的。

  • **kwargs – 传递给前向计算的关键字参数。

返回

一个字典,将模块路径映射到模块实例。

param(name, init_fn, *init_args, unbox=True, **init_kwargs)[source]#

声明并返回此模块中的参数。

参数是名为 “params” 的集合中的只读变量。有关变量的更多详细信息,请参阅 flax.core.variables

init_fn 的第一个参数被假定为 PRNG 密钥,该密钥会自动提供,无需使用 init_argsinit_kwargs 传递。

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}

在上例中,函数 lecun_normal 期待两个参数:keyshape,但只需明确提供 shapekey 会使用用于 params 的 PRNG 自动设置,该 PRNG 在使用 init() 初始化模块时传递。

参数
  • name – 参数名称。

  • init_fn – 将被调用以计算此变量初始值的函数。该函数将只在第一次在该模块中使用该参数时被调用。

  • *init_args – 传递给 init_fn 的位置参数。

  • unbox – 如果为 True,则 AxisMetadata 实例将被替换为其未装箱的值,请参阅 flax.nn.meta.unbox(默认值:True)。

  • **init_kwargs – 传递给 init_fn 的关键字参数。

返回

已初始化参数的值。如果参数已存在,则会抛出错误。

property path#

获取此模块的路径。顶级根模块具有空路径 ()。请注意,此方法只能用于具有有效范围的绑定模块。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class SubModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'SubModel path: {self.path}')
...     return x

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'Model path: {self.path}')
...     return SubModel()(x)

>>> model = Model()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
Model path: ()
SubModel path: ('SubModel_0',)
perturb(name, value, collection='perturbations')[source]#

向中间值添加一个零值变量(“扰动”)。

value 的梯度将与该扰动变量的梯度相同。因此,如果使用 params 和 perturbations 作为独立参数定义损失函数,则可以通过对 perturbation 参数运行 jax.grad 来获取 value 的中间梯度。

注意

这是一个实验性 API,可能会在以后针对性能和可用性进行调整。在其当前阶段,它会创建额外的虚拟变量,这些变量会占用额外的内存空间。仅将其用于调试训练中的梯度。

示例

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]

如果未将 perturbations 传递给 apply,则 perturb 将表现为无操作,因此可以轻松地在不需要时禁用该行为

>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True
put_variable(col, name, value)[source]#

如果给定变量是可变的,则更新其值,否则会抛出错误。

参数
  • col – 变量集合。

  • name – 变量的名称。

  • value – 变量的新值。

setup()[source]#

延迟初始化模块(类似于延迟 __init__)。

setup 会在绑定模块时在模块实例上延迟调用一次,在调用任何其他方法(如 __call__)之前,或在访问 setup 定义的 self 上的属性之前。

这种情况可能发生在三种情况下

  1. 在调用 apply()init()init_and_output() 时立即调用。

  2. 在通过在另一个模块的 setup 方法中将其分配给另一个模块的属性来为模块命名后(请参阅 __setattr__())。

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # either causes setup() to be called once.
    
  3. 在使用 compact() 包装的方法内构造模块后,在调用另一个方法或访问 setup 定义的属性之前立即调用。

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#

将值存储在集合中。

集合可用于收集中间值,而无需在每个模块调用时明确传递容器。

如果目标集合不可变,则 sow 会表现为无操作,并返回 False

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}

默认情况下,值将存储在一个元组中,每个存储的值都会附加到末尾。这样,当多次调用同一模块时,就可以跟踪所有中间值。或者,可以传递自定义的 init/reduce 函数。

>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
参数
  • col – 变量集合的名称。

  • name – 变量的名称。

  • value – 变量的值。

  • reduce_fn – 用于将现有值与新值结合的函数。默认情况下,将值附加到元组。

  • init_fn – 对于存储的第一个值,reduce_fn 将会传递 init_fn 的结果以及要存储的值。默认值为一个空元组。

返回

True 如果值已成功存储,否则为 False

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#

创建以表格形式表示的模块摘要。

此方法具有与 Module.init 相同的签名,并且内部调用它,但是它不返回变量,而是返回以表格形式汇总模块的字符串。 tabulate 使用 jax.eval_shape 来运行前向计算,而不会消耗任何 FLOP 或分配内存。

可以将其他参数传递到 console_kwargs 参数中,例如,{'width': 120}。有关 console_kwargs 参数的完整列表,请参见:https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))

>>> # print(Foo().tabulate(
>>> #     jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))

这将产生以下输出

                                      Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                              Total Parameters: 50 (200 B)

**注意**:表格中的行顺序不代表执行顺序,而是与 variables 中按字母顺序排序的键的顺序一致。

**注意**:如果模块不可微分,则 vjp_flops 将返回 0

参数
  • rngs – 变量集合的 rngs,如传递给 Module.init

  • *args – 前向计算的参数。

  • **depth** – 控制摘要可以深入到多少个子模块。默认情况下,它是 None,这意味着没有限制。如果子模块由于深度限制而未显示,则其参数计数和字节将添加到其第一个显示祖先的行,以使所有行的总和始终加起来等于模块的总参数数量。

  • show_repeated – 如果为 True,则将显示对同一模块的重复调用,否则仅显示第一次调用。默认值为 False

  • mutable – 可以是布尔值、字符串或列表。指定哪些集合应该被视为可变的:bool:所有/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合的名称列表。默认情况下,除了 ‘intermediates’ 之外的所有集合都是可变的。

  • **console_kwargs** – 一个可选字典,包含传递给 rich.console.Console 的其他关键字参数,用于渲染表格。默认参数为 {'force_terminal': True, 'force_jupyter': False}

  • **table_kwargs** – 一个可选字典,包含传递给 rich.table.Table 构造函数的其他关键字参数。

  • **column_kwargs** – 一个可选字典,包含传递给 rich.table.Table.add_column 的其他关键字参数,用于向表格中添加列。

  • **compute_flops** – 是否在表格中包含一个 flops 列,列出每个模块前向传递的估计 FLOP 成本。确实会产生实际的设备上计算/编译/内存分配,但对于大型模块仍会引入开销(例如,Stable Diffusion 的 UNet 额外需要 20 秒,而其他情况下制表将在 5 秒内完成)。

  • **compute_vjp_flops** – 是否在表格中包含一个 vjp_flops 列,列出每个模块反向传递的估计 FLOP 成本。引入的计算开销约为 compute_flops 的 2-3 倍。

  • **kwargs – 传递给前向计算的关键字参数。

返回

一个总结模块的字符串。

unbind()[source]#

返回一个模块及其变量的未绑定副本。

unbind 有助于创建绑定模块的无状态版本。

一个常见用例的示例:提取在 setup() 中定义的子模块及其相应的变量:1)临时 bind 父模块;然后 2) unbind 所需的子模块。(请记住,setup() 仅在绑定模块时才会被调用。)

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()
返回

一个元组,其中包含该模块及其变量的未绑定副本。

variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)[source]#

在此模块中声明并返回一个变量。

有关更多信息,请参见 flax.core.variables。另请参见 param(),了解在“params”集合中定义只读变量的简写方式。

param() 不同,使用 init_fn 传递的所有参数都应明确传递。

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}

在上面的示例中,函数 lecun_normal 预期两个参数:keyshape,并且两者都必须传递。当调用 init()apply() 时,必须显式提供 stats 的 PRNG。

参数
  • col – 变量集合名称。

  • **name** – 变量名称。

  • **init_fn** – 将被调用以计算此变量初始值的函数。此函数仅会在此模块中第一次使用此变量时被调用。如果为 None,则变量必须已初始化,否则将引发错误。

  • *init_args – 传递给 init_fn 的位置参数。

  • unbox – 如果为 True,则 AxisMetadata 实例将被替换为其未装箱的值,请参阅 flax.nn.meta.unbox(默认值:True)。

  • **init_kwargs** – 要传递给 init_fn 的关键字参数。

返回

一个 flax.core.variables.Variable,可以通过“”.value”属性进行读取或设置。如果变量已存在,则会抛出错误。

property variables#

返回此模块中的变量。

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#

创建一个应用函数,以将绑定模块与 fn 一起调用。

Module.apply 不同,此函数返回一个新函数,其签名为 (variables, *args, rngs=None, **kwargs) -> T,其中 Tfn 的返回类型。如果 mutable 不是 False,则返回类型是一个元组,其中第二个项目是包含已变异变量的 FrozenDict

返回的应用函数可以直接与 JAX 变换组合,例如 jax.jit

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
参数
  • **fn** – 应应用的函数。传递的第一个参数将是一个模块实例,该实例是具有变量和 RNG 绑定的 module

  • **module** – 将用于将变量和 RNG 绑定的 Module。作为 fn 的第一个参数传递的 Module 将是 module 的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。

  • **capture_intermediates** – 如果为 True,则捕获所有模块在“intermediates”集合中的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数获取模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回

包装 fn 的应用函数。

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

创建一个 init 函数,使用绑定模块调用 fn

Module.init 不同,此函数返回一个具有签名 (rngs, *args, **kwargs) -> variables 的新函数。 rngs 可以是 PRNGKeys 字典或单个 `PRNGKey`,这相当于传递一个带有名称为“params”的单个 PRNGKey 的字典。

返回的 init 函数可以直接与 JAX 变换(如 jax.jit)组合。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
参数
  • **fn** – 应应用的函数。传递的第一个参数将是一个模块实例,该实例是具有变量和 RNG 绑定的 module

  • **module** – 将用于将变量和 RNG 绑定的 Module。作为 fn 的第一个参数传递的 Module 将是 module 的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果为 True,则将所有模块内部的中间返回值捕获到“intermediates”集合中。 默认情况下,仅存储所有 __call__ 方法的返回值。 可以传递一个函数来更改过滤器行为。 过滤器函数接受模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回

包装 fn 的 init 函数。

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

创建一个 init 函数,使用绑定模块调用 fn,并返回函数输出。

Module.init_with_output 不同,此函数返回一个具有签名 (rngs, *args, **kwargs) -> (T, variables) 的新函数,其中 Tfn 的返回类型。 rngs 可以是 PRNGKeys 字典或单个 `PRNGKey`,这相当于传递一个带有名称为“params”的单个 PRNGKey 的字典。

返回的 init 函数可以直接与 JAX 变换(如 jax.jit)组合。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
参数
  • **fn** – 应应用的函数。传递的第一个参数将是一个模块实例,该实例是具有变量和 RNG 绑定的 module

  • **module** – 将用于将变量和 RNG 绑定的 Module。作为 fn 的第一个参数传递的 Module 将是 module 的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变的:bool:所有集合/无集合都是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。

  • **capture_intermediates** – 如果为 True,则捕获所有模块在“intermediates”集合中的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数获取模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回

包装 fn 的 init 函数。

flax.linen.intercept_methods(interceptor)[source]#

注册新的方法拦截器。

方法拦截器允许您(从远处)拦截对模块的 方法调用。 它类似于装饰器。 您可以在调用底层方法之前修改 args/kwargs,或者修改从调用底层方法返回的结果。 或者您也可以完全跳过调用底层方法,并决定执行其他操作。 例如

>>> import flax.linen as nn
>>> import jax.numpy as jnp
...
>>> class Foo(nn.Module):
...   def __call__(self, x):
...     return x
...
>>> def my_interceptor1(next_fun, args, kwargs, context):
...   print('calling my_interceptor1')
...   return next_fun(*args, **kwargs)
...
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1

您也可以在同一方法上注册多个拦截器。 拦截器将按顺序运行。 例如

>>> def my_interceptor2(next_fun, args, kwargs, context):
...   print('calling my_interceptor2')
...   return next_fun(*args, **kwargs)
...
>>> with nn.intercept_methods(my_interceptor1), \
...      nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2

您可以通过直接调用 context.orig_method 来跳过其他拦截器。 例如

>>> def my_interceptor3(next_fun, args, kwargs, context):
...   print('calling my_interceptor3')
...   return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
...      nn.intercept_methods(my_interceptor1), \
...      nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor3

以下方法无法拦截

  1. 使用 nn.nowrap 装饰的方法。

  2. 双下划线方法,包括 __eq____repr____init____hash____post_init__

  3. 模块数据类字段。

  4. 模块描述符。

参数

interceptor – 方法拦截器。

flax.linen.share_scope(module, other, /)[source]#

修改其中一个模块,使它们共享相同的范围。 当您想要包装模块并扩展其功能而不改变参数结构时,这很有用。

share_scope 接受两个模块,moduleother。 如果 other 具有范围,并且其不是 ``module`` 范围的子级,则 module 将使用 other 的范围。

>>> import flax.linen as nn
>>> import jax
>>> from jax import numpy as jnp, random
...
>>> class DenseLoRA(nn.Module):
...   base: nn.Dense
...   rank: int
...
...   def setup(self):
...     nn.share_scope(self, self.base)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.base.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.base(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     dense = nn.Dense(10) # base scope
...     return DenseLoRA(dense, rank=2)(x) # reuse the base scope
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['Dense_0'].keys())
['A', 'B', 'kernel', 'bias']

other 的范围是 module 范围的子级时,other 将改为使用 module 的范围。

>>> class DenseLoRA(nn.Module):
...   features: int
...   rank: int
...
...   def setup(self):
...     self.child = nn.Dense(self.features)
...     nn.share_scope(self, self.child)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.child(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     return DenseLoRA(10, rank=2)(x)
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['DenseLoRA_0'].keys())
['A', 'B', 'kernel', 'bias']