内容

#

线性模块#

class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

应用于输入最后一维度的线性变换。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Dense(features=4)
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4,), 'kernel': (3, 4)}}
features#

输出特征的数量。

类型

int

use_bias#

是否在输出中添加偏差(默认:True)。

类型

bool

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

权重矩阵的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)[source]#

对输入沿着最后一个维度应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

具有灵活轴的线性变换。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # equivalent to `nn.Dense(features=4)`
>>> layer = nn.DenseGeneral(features=4)
>>> # output features (4, 5)
>>> layer = nn.DenseGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the the second and last axes
>>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
features#

输出特征数量的整型或元组。

类型

int | collections.abc.Sequence[int]

axis#

要应用变换的轴的整型或元组。例如,(-2, -1) 将把变换应用于最后两个轴。

类型

int | collections.abc.Sequence[int]

batch_dims#

包含批次轴的元组。

类型

collections.abc.Sequence[int]

use_bias#

是否在输出中添加偏差(默认:True)。

类型

bool

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

kernel_init#

权重矩阵的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

__call__(inputs)[source]#

对输入沿着多个维度应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

包装 `lax.conv_general_dilated` 的卷积模块。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。单个整数将被解释为单个整数的元组。

类型

int | collections.abc.Sequence[int]

strides#

一个整数或一个 `n` 个整数的序列,表示窗口间步长(默认值:1)。

类型

None | int | collections.abc.Sequence[int]

padding#

字符串 `'SAME'`,字符串 `'VALID'`,字符串 `'CIRCULAR'`(周期性边界条件)之一,或者一个 `n` 个 `(low, high)` 整数对的序列,这些整数对给出要对每个空间维度应用的填充。单个 int 被解释为在所有维度上应用相同的填充,并在序列中分配单个 int 会导致在两侧使用相同的填充。对于一维卷积,`'CAUSAL'` 填充将对卷积轴进行左填充,从而产生相同大小的输出。

类型

Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]

input_dilation#

一个整数或一个 `n` 个整数的序列,给出要对 `inputs` 的每个空间维度应用的膨胀因子(默认值:1)。具有输入膨胀 `d` 的卷积等效于具有步长 `d` 的转置卷积。

类型

None | int | collections.abc.Sequence[int]

kernel_dilation#

一个整数或一个 `n` 个整数的序列,给出要对卷积核的每个空间维度应用的膨胀因子(默认值:1)。具有核膨胀的卷积也称为“空洞卷积”。

类型

None | int | collections.abc.Sequence[int]

feature_group_count#

整数,默认值为 1。如果指定,则将输入特征划分为组。

类型

int

use_bias#

是否在输出中添加偏差(默认:True)。

类型

bool

mask#

在掩码卷积期间,权重的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。

类型

Optional[Union[jax.Array, Any]]

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,请参阅 ``jax.lax.Precision` 获取详细信息。

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

卷积核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)#

对输入应用(可能不共享的)卷积。

参数

inputs – 具有维度 `(*batch_dims, spatial_dims..., features)` 的输入数据。这是 channels-last 约定,即二维卷积的 NHWC 和三维卷积的 NDHWC。注意:这与 `lax.conv_general_dilated` 使用的输入约定不同,它将空间维度放在最后。注意:如果输入具有多个批次维度,所有批次维度都将被展平成单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层可能比这种默认展平方法产生更好的性能。如果输入缺少批次维度,它将被添加到卷积中,并在返回时删除,这是一个允许编写单示例代码的允许值。

返回

卷积后的数据。

方法

class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

封装了 lax.conv_transpose 的卷积模块。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 10, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
>>> out.shape
(1, 30, 30, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。对于一维卷积,核大小可以作为整数传递,将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。

类型

int | collections.abc.Sequence[int]

strides#

一个整数或一个 n 个整数的序列,表示窗口间步长。

类型

collections.abc.Sequence[int] | None

padding#

字符串 ‘SAME’、字符串 ‘VALID’、字符串 ‘CIRCULAR’(周期性边界条件)或 n(low, high) 整数对的序列,这些对给出要应用于每个空间维度的填充。单个 int 被解释为在所有维度中应用相同的填充,并在序列中分配单个 int 导致在两侧使用相同的填充。

类型

Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]

kernel_dilation#

None 或一个整数或 n 个整数的序列,给出要应用于卷积核每个空间维度的膨胀因子。具有内核膨胀的卷积也被称为“空洞卷积”。

类型

collections.abc.Sequence[int] | None

use_bias#

是否在输出中添加偏差(默认:True)。

类型

bool

mask#

在掩码卷积期间,权重的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。

类型

Optional[Union[jax.Array, Any]]

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

卷积核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

transpose_kernel#

如果 True 翻转空间轴并交换内核的输入/输出通道轴。

类型

bool

__call__(inputs)[source]#

对输入应用转置卷积。

行为镜像 jax.lax.conv_transpose

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features). 这是通道最后约定,即二维卷积的 NHWC 和三维卷积的 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度将被展平为单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接对层进行 vmap 可能比这种默认展平方法产生更好的性能。如果输入缺少批次维度,则将为卷积添加它,并在返回时删除,允许编写单个示例代码。

返回

卷积后的数据。

方法

class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

封装了 lax.conv_general_dilated_local 的局部卷积模块。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((6, 9, 4)))
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。单个整数将被解释为单个整数的元组。

类型

int | collections.abc.Sequence[int]

strides#

一个整数或一个 `n` 个整数的序列,表示窗口间步长(默认值:1)。

类型

None | int | collections.abc.Sequence[int]

padding#

字符串 `'SAME'`,字符串 `'VALID'`,字符串 `'CIRCULAR'`(周期性边界条件)之一,或者一个 `n` 个 `(low, high)` 整数对的序列,这些整数对给出要对每个空间维度应用的填充。单个 int 被解释为在所有维度上应用相同的填充,并在序列中分配单个 int 会导致在两侧使用相同的填充。对于一维卷积,`'CAUSAL'` 填充将对卷积轴进行左填充,从而产生相同大小的输出。

类型

Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]

input_dilation#

一个整数或一个 `n` 个整数的序列,给出要对 `inputs` 的每个空间维度应用的膨胀因子(默认值:1)。具有输入膨胀 `d` 的卷积等效于具有步长 `d` 的转置卷积。

类型

None | int | collections.abc.Sequence[int]

kernel_dilation#

一个整数或一个 `n` 个整数的序列,给出要对卷积核的每个空间维度应用的膨胀因子(默认值:1)。具有核膨胀的卷积也称为“空洞卷积”。

类型

None | int | collections.abc.Sequence[int]

feature_group_count#

整数,默认值为 1。如果指定,则将输入特征划分为组。

类型

int

use_bias#

是否在输出中添加偏差(默认:True)。

类型

bool

mask#

在掩码卷积期间,权重的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。

类型

Optional[Union[jax.Array, Any]]

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

卷积核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)#

对输入应用(可能不共享的)卷积。

参数

inputs – 具有维度 `(*batch_dims, spatial_dims..., features)` 的输入数据。这是 channels-last 约定,即二维卷积的 NHWC 和三维卷积的 NDHWC。注意:这与 `lax.conv_general_dilated` 使用的输入约定不同,它将空间维度放在最后。注意:如果输入具有多个批次维度,所有批次维度都将被展平成单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层可能比这种默认展平方法产生更好的性能。如果输入缺少批次维度,它将被添加到卷积中,并在返回时删除,这是一个允许编写单示例代码的允许值。

返回

卷积后的数据。

方法

class flax.linen.Einsum(shape, einsum_str=None, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

具有可学习内核和偏差的 einsum 变换。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Einsum((5, 6, 7), 'abc,cde->abde')
>>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}}
shape#

内核的形状。

类型

collections.abc.Sequence[int]

einsum_str#

用于表示 einsum 方程的字符串。该方程必须恰好有两个操作数,lhs 是传递的输入,rhs 是可学习的内核。构造函数参数和调用参数中的 einsum_str 必须有一个不为 None,而另一个必须为 None。

类型

str | None

use_bias#

是否在输出中添加偏差(默认:True)。

类型

bool

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

权重矩阵的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs, einsum_str=None)[source]#

对输入沿着最后一个维度应用线性变换。

参数
  • inputs – 要变换的 nd 数组。

  • einsum_str – 表示 einsum 方程式的字符串。该方程式必须正好有两个操作数,左侧是传入的输入,右侧是可学习的核。调用方法中传入的 einsum_str 将优先于构造函数中传入的 einsum_str

返回

变换后的输入。

方法

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

嵌入模块。

从整数 [0, num_embeddings) 到 features 维向量的一个参数化函数。此 Module 将创建一个形状为 (num_embeddings, features)embedding 矩阵。当调用此层时,输入值将用于 0 索引到 embedding 矩阵中。对大于或等于 num_embeddings 的值的索引将导致 nan 值。当 num_embeddings 等于 1 时,它将广播 embedding 矩阵到输入形状,并附加 features 维度。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Embed(num_embeddings=5, features=3)
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[-0.28884724,  0.19018005, -0.414205  ],
       [-0.11768015, -0.54618824, -0.3789283 ],
       [ 0.30428642,  0.49511626,  0.01706631],
       [-0.0982546 , -0.43055868,  0.20654906],
       [-0.688412  , -0.46882293,  0.26723292]], dtype=float32)}}
>>> # get the first three and last three embeddings
>>> layer.apply(variables, indices_input)
Array([[[-0.28884724,  0.19018005, -0.414205  ],
        [-0.11768015, -0.54618824, -0.3789283 ],
        [ 0.30428642,  0.49511626,  0.01706631]],

       [[-0.688412  , -0.46882293,  0.26723292],
        [-0.0982546 , -0.43055868,  0.20654906],
        [ 0.30428642,  0.49511626,  0.01706631]]], dtype=float32)
num_embeddings#

嵌入数量 / 词汇量大小。

类型

int

features#

每个嵌入的特征维数。

类型

int

dtype#

嵌入向量的 dtype(默认值:与嵌入相同)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

embedding_init#

嵌入初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)[source]#

沿最后一个维度嵌入输入。

参数

inputs – 输入数据,所有维度都被视为批处理维度。输入数组中的值必须是整数。

返回

输出是嵌入的输入数据。输出形状遵循输入,并在其后附加一个 features 维度。

attend(query)[source]#

使用查询数组对嵌入进行关注。

参数

query – 数组,其最后一个维度等于嵌入的特征深度 features

返回

一个数组,其最后一个维度为 num_embeddings,对应于查询向量数组与每个嵌入的批处理内积。通常用于 NLP 模型中嵌入和 logits 变换之间的权重共享。

方法

attend(query)

使用查询数组对嵌入进行关注。

池化#

flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#

通过对窗口切片取最大值来池化输入。

参数
  • inputs – 输入数据,维度为 (batch, window dims…, features)。

  • window_shape – 一个形状元组,定义要减少的窗口。

  • strides – 一系列 n 个整数,表示窗口间步长(默认值:(1, ..., 1))。

  • padding – 既可以是字符串 'SAME',也可以是字符串 'VALID',或者是一系列 n(low, high) 整数对,用于给出要应用于每个空间维度的前后的填充(默认值:'VALID')。

返回

每个窗口切片的最大值。

flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#

通过对窗口取平均值来池化输入。

参数
  • inputs – 输入数据,维度为 (batch, window dims…, features)。

  • window_shape – 一个形状元组,定义要减少的窗口。

  • strides – 一系列 n 个整数,表示窗口间步长(默认值:(1, ..., 1))。

  • padding – 既可以是字符串 'SAME',也可以是字符串 'VALID',或者是一系列 n(low, high) 整数对,用于给出要应用于每个空间维度的前后的填充(默认值:'VALID')。

  • count_include_pad – 一个布尔值,表示是否将填充的标记包含在平均值计算中(默认值:True)。

返回

每个窗口切片的平均值。

flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[source]#

用于定义池化函数的辅助函数。

池化函数是使用 ReduceWindow XLA 操作实现的。

注意

请注意,池化通常不可微。这意味着提供一个可微的 reduce_fn 并不意味着 pool 可微。

参数
  • inputs – 输入数据,维度为 (batch, window dims…, features)。

  • init – 减少的初始值

  • reduce_fn – 形式为 (T, T) -> T 的减少函数。

  • window_shape – 一个形状元组,定义要减少的窗口。

  • strides – 一系列 n 个整数,表示窗口间步长(默认值:(1, ..., 1))。

  • padding – 既可以是字符串 'SAME',也可以是字符串 'VALID',或者是一系列 n(low, high) 整数对,用于给出要应用于每个空间维度的前后的填充。

返回

每个窗口切片的减少输出。

规范化#

class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

BatchNorm 模块。

使用说明:如果我们定义一个带 BatchNorm 的模型,例如

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> BN = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=jnp.float32)

初始化的变量字典除了包含 ‘params’ 集合外,还会包含一个单独的 ‘batch_stats’ 集合,该集合将包含模型中所有 BatchNorm 层的所有运行统计数据

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> variables = BN.init(jax.random.key(1), x, use_running_average=False)
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}}

然后,我们在训练期间通过指定 batch_stats 集合在模块的 apply 方法中是可变的来更新 batch_stats。

>>> y, new_batch_stats = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=False)

在评估过程中,我们将使用 use_running_average=True 定义 BN,并使用来自训练的 batch_stats 集合来设置统计信息。在这种情况下,我们不会修改 batch statistics 集合,因此不需要将其标记为可变的。

>>> y = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=True)
use_running_average#

如果为 True,则将使用 batch_stats 中存储的统计信息,而不是计算输入的批次统计信息。

类型

bool | None

axis#

输入的特征或非批次轴。

类型

int

momentum#

批次统计信息的指数移动平均的衰减率。

类型

float

epsilon#

添加到方差中的一个小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差(beta)。

类型

bool

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认值为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认值为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅 jax.pmap(默认值:None)。请注意,这仅用于 pmap 和 shard map。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。

类型

str | None

axis_index_groups#

在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 True,则使用更快但数值稳定性较差的方差计算方法。

类型

bool

__call__(x, use_running_average=None, *, mask=None)[source]#

使用批次统计信息对输入进行归一化。

注意

在初始化期间(当 self.is_initializing()True 时),批次统计信息的运行平均值不会更新。因此,在初始化期间馈送的输入不需要与实际输入分布相匹配,并且约简轴(使用 axis_name 设置)不需要存在。

参数
  • x – 要归一化的输入。

  • use_running_average – 如果为 true,则将使用 batch_stats 中存储的统计信息,而不是计算输入的批次统计信息。

  • mask – 形状可广播到 inputs 张量的二元数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

层归一化 (https://arxiv.org/abs/1607.06450).

LayerNorm 对批次中每个给定示例的层激活进行归一化,而不是像批次归一化一样跨批次进行归一化。即应用变换,使每个示例内的平均激活保持接近 0,并且激活标准差保持接近 1。

注意

此归一化操作与 InstanceNorm 和 GroupNorm 相同;区别仅在于哪些轴被约简以及特征轴的形状(即可学习比例和偏差参数的形状)。

示例用法

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.LayerNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> y = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> y2 = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
epsilon#

添加到方差中的一个小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差(beta)。

类型

bool

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认值为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认值为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

reduction_axes#

用于计算归一化统计信息的轴。

类型

Union[int, collections.abc.Sequence[int]]

feature_axes#

用于学习偏差和缩放的特征轴。

类型

Union[int, collections.abc.Sequence[int]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅 jax.pmap(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。

类型

str | None

axis_index_groups#

在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 True,则使用更快但数值稳定性较差的方差计算方法。

类型

bool

__call__(x, *, mask=None)[source]#

对输入应用层归一化。

参数
  • x – 输入

  • mask – 形状可广播到 inputs 张量的二元数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

组归一化 (arxiv.org/abs/1803.08494).

该操作类似于批归一化,但统计信息在大小相等的通道组之间共享,而不是在批次维度之间共享。因此,组归一化不依赖于批次组成,也不需要维护内部状态来存储统计信息。用户应该指定通道组的总数或每个组的通道数。

注意

LayerNorm 是 GroupNorm 的特例,其中 num_groups=1,而 InstanceNorm 是 GroupNorm 的特例,其中 group_size=1

示例用法

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.GroupNorm(num_groups=3)
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> y = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> y2 = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
num_groups#

通道组的总数。原始组归一化论文建议默认值为 32。

类型

int | None

group_size#

每个组的通道数。

类型

int | None

epsilon#

添加到方差中的一个小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差(beta)。

类型

bool

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认值为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认值为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

reduction_axes#

用于计算归一化统计信息的轴列表。此列表必须包含最后一个维度,该维度被假定为特征轴。此外,如果调用时使用的输入与用于初始化的数据相比具有额外的领先轴,例如由于批处理,则需要明确定义归约轴。

类型

Optional[Union[int, collections.abc.Sequence[int]]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅 jax.pmap(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。

类型

str | None

axis_index_groups#

在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 True,则使用更快但数值稳定性较差的方差计算方法。

类型

bool

__call__(x, *, mask=None)[source]#

将组归一化应用于输入 (arxiv.org/abs/1803.08494).

参数
  • x – 形状为 ...C 的输入,其中 C 是通道维度,而 ... 表示任意数量的额外维度,这些维度可用于在统计信息上累积。如果没有指定归约轴,则除了假定代表批次的第一个维度之外,所有额外维度 ... 将用于累积统计信息。

  • mask – 形状可广播到 inputs 张量的二元数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.RMSNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

RMS 层归一化 (https://arxiv.org/abs/1910.07467).

RMSNorm 独立地为批次中的每个示例规范化层的激活,而不是像批次归一化一样跨批次规范化。与 LayerNorm 将平均值重新居中为 0 并通过激活的标准差进行规范化不同,RMSNorm 根本不重新居中,而是通过激活的均方根进行规范化。

示例用法

>>> import flax.linen as nn
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nn.RMSNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}}
>>> y = layer.apply(variables, x)
epsilon#

添加到方差中的一个小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。

类型

bool

scale_init#

比例的初始化器,默认值为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

reduction_axes#

用于计算归一化统计信息的轴。

类型

Union[int, collections.abc.Sequence[int]]

feature_axes#

用于学习偏差和缩放的特征轴。

类型

Union[int, collections.abc.Sequence[int]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅 jax.pmap(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。

类型

str | None

axis_index_groups#

在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 True,则使用更快但数值稳定性较差的方差计算方法。

类型

bool

__call__(x, *, mask=None)[source]#

对输入应用 RMS 层归一化。

参数
  • x – 输入

  • mask – 形状可广播到 inputs 张量的二元数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.InstanceNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

实例归一化 (https://arxiv.org/abs/1607.08022v3).

InstanceNorm 为每个通道(而不是像 Layer Normalization 那样跨所有通道)以及独立地为批次中的每个示例(而不是像批次归一化那样跨整个批次)规范化层的激活。即,应用一个变换,使每个示例内每个通道内的平均激活值接近 0,而激活标准差接近 1。

注意

此归一化操作与 LayerNorm 和 GroupNorm 相同;唯一的区别在于归约哪些轴以及特征轴的形状(即可学习缩放和偏差参数的形状)。

示例用法

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> # dimensions: (batch, height, width, channel)
>>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
>>> layer = nn.InstanceNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
>>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
>>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2, atol=1e-7)
>>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x)
>>> np.testing.assert_allclose(y, y3, atol=1e-7)
epsilon#

添加到方差中的一个小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差(beta)。

类型

bool

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认值为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认值为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

feature_axes#

特征轴。学习到的偏置和缩放参数将具有特征轴定义的形状。除了批处理轴(假设为前导轴)之外的所有其他轴都将被减少。

类型

Union[int, collections.abc.Sequence[int]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅 jax.pmap(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。

类型

str | None

axis_index_groups#

在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 True,则使用更快但数值稳定性较差的方差计算方法。

类型

bool

__call__(x, *, mask=None)[source]#

对输入应用实例归一化。

参数
  • x – 输入

  • mask – 形状可广播到 inputs 张量的二元数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

谱归一化。

谱归一化将权重参数归一化,以便矩阵的谱范数等于 1。这实现为一个层包装器,其中每个包装层在其计算其 __call__ 输出之前会对其参数进行谱归一化。

注意

初始化的变量字典除了包含一个“params”集合之外,还包含一个单独的“batch_stats”集合,该集合将包含一个 u 向量和 sigma 值,这些值是执行谱归一化时使用的中间值。在训练期间,我们传入 update_stats=Truemutable=['batch_stats'],以便使用幂迭代方法计算的最新值更新 usigma。这将有助于幂迭代方法随着时间的推移更准确地逼近真实的奇异值。在评估期间,我们传入 update_stats=False 以确保我们从模型中获得确定性行为。

示例用法

>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(3)(x)
...     # only spectral normalize the params of the second Dense layer
...     x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
...     x = nn.Dense(5)(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 5))
>>> model = Foo()
>>> variables = model.init(jax.random.PRNGKey(0), x, train=False)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    batch_stats: {
        SpectralNorm_0: {
            Dense_1/kernel/sigma: (),
            Dense_1/kernel/u: (1, 4),
        },
    },
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
    },
})

>>> # train
>>> def train_step(variables, x, y):
...   def loss_fn(params):
...     logits, updates = model.apply(
...         {'params': params, 'batch_stats': variables['batch_stats']},
...         x,
...         train=True,
...         mutable=['batch_stats'],
...     )
...     loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
...     return loss, updates
...
...   (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
...       variables['params']
...   )
...   return {
...       'params': jax.tree_util.tree_map(
...           lambda p, g: p - 0.1 * g, variables['params'], grads
...       ),
...       'batch_stats': updates['batch_stats'],
...   }, loss
>>> for _ in range(10):
...   variables, loss = train_step(variables, x, y)

>>> # inference / eval
>>> out = model.apply(variables, x, train=False)
layer_instance#

用 SpectralNorm 包装的模块实例

类型

flax.linen.module.Module

n_steps#

执行多少步幂迭代以逼近权重参数的奇异值。

类型

int

epsilon#

添加到 l2 规范化中的一个小浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

error_on_non_matrix#

谱归一化仅针对矩阵定义。默认情况下,此模块将返回未改变的标量并将更高阶张量在其前导维度中展平。将此标志设置为 True 将在层使用维度大于 2 的权重张量时抛出错误。

类型

bool

collection_name#

存储执行谱归一化时使用的中间值的集合的名称。

类型

str

__call__(*args, update_stats, **kwargs)[source]#

使用幂迭代方法计算 self.layer_instance 中权重的最大奇异值,并使用此值归一化权重,然后计算 __call__ 输出。

参数
  • *args – 传递给 self.layer_instance 中底层层实例的调用方法的位置参数。

  • update_stats – 如果为 True,则在使用幂迭代方法计算其更新的值后更新内部 u 向量和 sigma 值。这将有助于幂迭代方法随着时间的推移更准确地逼近真实的奇异值。

  • **kwargs – 传递给 self.layer_instance 中底层层实例的调用方法的关键字参数。

返回

使用谱归一化权重的层的输出。

方法

class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

L2 权重规范化 (https://arxiv.org/abs/1602.07868).

权重规范化将权重参数归一化,以便矩阵的 l2 范数等于 1。这实现为一个层包装器,其中每个包装层在其计算其 __call__ 输出之前会对其参数进行 l2 归一化。

示例用法

>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Baz(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Dense(2)(x)

>>> class Bar(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     return x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     # l2-normalize all params of the second Dense layer
...     x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
...     x = nn.Dense(5)(x)
...     # l2-normalize all kernels in the Bar submodule and all params in
...     # the Baz submodule
...     x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    params: {
        Bar_0: {
            Baz_0: {
                Dense_0: {
                    bias: (2,),
                    kernel: (5, 2),
                },
            },
            Baz_1: {
                Dense_0: {
                    bias: (2,),
                    kernel: (3, 2),
                },
            },
            Dense_0: {
                bias: (3,),
                kernel: (2, 3),
            },
            Dense_1: {
                bias: (3,),
                kernel: (2, 3),
            },
        },
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
        WeightNorm_0: {
            Dense_1/bias/scale: (4,),
            Dense_1/kernel/scale: (4,),
        },
        WeightNorm_1: {
            Bar_0/Baz_0/Dense_0/bias/scale: (2,),
            Bar_0/Baz_0/Dense_0/kernel/scale: (2,),
            Bar_0/Baz_1/Dense_0/bias/scale: (2,),
            Bar_0/Baz_1/Dense_0/kernel/scale: (2,),
            Bar_0/Dense_0/kernel/scale: (3,),
            Bar_0/Dense_1/kernel/scale: (3,),
        },
    },
})
layer_instance#

用 WeightNorm 包装的模块实例

类型

flax.linen.module.Module

epsilon#

添加到 l2 规范化中的一个小浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_scale#

如果为 True,则创建一个可学习的变量 scale,该变量在 l2 规范化后乘以 layer_instance 变量。

类型

bool

scale_init#

缩放函数的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

feature_axes#

特征轴维度。l2 范数是通过在剩余的(非特征)轴上减少 layer_instance 变量来计算的。因此,为每个指定的特征计算一个单独的 l2 范数值,并学习一个单独的比例因子(如果 use_scale=True)。默认情况下,尾随维度被视为特征轴。

类型

Optional[Union[int, collections.abc.Sequence[int]]]

variable_filter#

一个可选的迭代器,它包含字符串项目。WeightNorm 层将选择性地将 l2 规范化应用于 layer_instance 变量,其键路径(由“/”分隔)与 variable_filter 相匹配。例如,variable_filter={'kernel'} 将仅将 l2 规范化应用于键路径包含“kernel”的变量。默认情况下,variable_filter={'kernel'}

类型

collections.abc.Iterable | None

__call__(*args, **kwargs)[source]#

计算 self.layer_instance 中权重的 l2 范数,并使用此值归一化权重,然后计算 __call__ 输出。

参数
  • *args – 传递给 self.layer_instance 中底层层实例的调用方法的位置参数。

  • **kwargs – 传递给 self.layer_instance 中底层层实例的调用方法的关键字参数。

返回

使用 l2 归一化权重的层的输出。

方法

组合器#

class flax.linen.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

应用一串线性模块。

仅用于将可调用对象融合在一起的简单情况,其中特定模块/操作的输入是前一个模块/操作的输出。

模块将按照它们在构造函数中传递的顺序应用。

Sequential 的 __call__ 方法接受任何输入并将其转发到它包含的第一个模块。 它将输出按顺序链接到下一个模块的输入,并返回最终模块的输出。

示例用法

>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Sequential([nn.Dense(4),
...                           nn.relu,
...                           nn.Dense(2),
...                           nn.log_softmax])(x)

由于 Sequential.__call__ 是一个 紧凑 方法,如果需要形状推断,您也可以传递构建模块的内联函数

module = nn.Sequential([
    # << more layers
    lambda x: SomeModule(x.shape[-1])(x), # shape inference
    # << more layers
])

此组合器还支持返回多个输出的层,如果作为元组或字典返回。 如果层的输出是一个 tuple,它将在下一层中扩展为 *args,如果它是一个 dict,它将在下一层中扩展为 **kwargs

示例用法

>>> class CrossAttentionBlock(nn.Module):
...   num_heads: int = 2
...   qkv_features: int = 16
...
...   @nn.compact
...   def __call__(self, query, key_value):
...     output = nn.MultiHeadDotProductAttention(
...       num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
...                                                                 key_value)
...     output = nn.Dense(self.qkv_features)(output)
...     return dict(query=output, key_value=key_value)  # also works for tuples

>>> from typing import Sequence
>>> class CrossAttentionNetwork(nn.Module):
...   num_layers: Sequence[int]
...
...   @nn.compact
...   def __call__(self, x):
...     return nn.Sequential([CrossAttentionBlock() for _ in
...                           range(self.num_layers)])(query, key_value)
layers#

要按顺序应用的可调用对象的序列。

类型

collections.abc.Sequence[collections.abc.Callable[[…], Any]]

引发

ValueError – 如果 layers 不是序列。

__call__(*args, **kwargs)[source]#

将 self 作为函数调用。

方法

随机#

class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

创建一个 dropout 层。

注意

使用 Module.apply() 时,请确保包含一个名为 'dropout' 的 RNG 种子。 dropout 对变量初始化并不必要。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(4)(x)
...     x = nn.Dropout(0.5, deterministic=not train)(x)
...     return x

>>> model = MLP()
>>> x = jnp.ones((1, 3))
>>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
>>> model.apply(variables, x, train=False) # don't use dropout
Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32)
>>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
Array([[ 0.       , -1.1856356, -1.0369378,  0.       ]], dtype=float32)
rate#

dropout 概率。 (_不是_ 保留率!)

类型

float

broadcast_dims#

将共享相同 dropout 掩码的维度

类型

collections.abc.Sequence[int]

deterministic#

如果为 false,则输入将按 1 / (1 - rate) 缩放并掩盖,而如果为 true,则不应用掩码并按原样返回输入。

类型

bool | None

rng_collection#

请求 rng 键时要使用的 rng 集合名称。

类型

str

__call__(inputs, deterministic=None, rng=None)[source]#

将随机 dropout 掩码应用于输入。

参数
  • inputs – 应该被随机掩盖的输入。

  • deterministic – 如果为 false,则输入将按 1 / (1 - rate) 缩放并掩盖,而如果为 true,则不应用掩码并按原样返回输入。

  • rng – 一个可选的 PRNGKey 用作随机键,如果未指定,则将使用 make_rngrng_collection 名称生成一个。

返回

被掩盖的输入,重新加权以保留均值。

方法

注意力#

class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

多头点积注意力。

示例用法

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
num_heads#

注意头的数量。 特征(即 inputs_q.shape[-1])应该可以被头数整除。

类型

int

dtype#

计算的 dtype(默认值:从输入和参数推断)

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

qkv_features#

键、查询和值的维度。

类型

int | None

out_features#

最后投影的维度

类型

int | None

broadcast_dropout#

使用沿批次维度的广播 dropout。

类型

bool

dropout_rate#

Dropout 率。

类型

float

deterministic#

如果为 False,则注意力权重将使用 dropout 随机掩盖,而如果为 True,则注意力权重是确定性的。

类型

bool | None

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

Dense 层内核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

out_kernel_init#

输出 Dense 层内核的可选初始化器,如果为 None,则将使用 kernel_init

类型

Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]

bias_init#

Dense 层偏差的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

out_bias_init#

输出 Dense 层偏差的可选初始化器,如果为 None,则将使用 bias_init

类型

Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]

use_bias#

点式 QKVO 密集变换是否使用偏差。

类型

bool

attention_fn#

dot_product_attention 或兼容函数。 接受查询、键、值并返回形状为 [bs, dim1, dim2, ..., dimN,, num_heads, value_channels] 的输出

类型

collections.abc.Callable[[…], Union[jax.Array, Any]]

decode#

是否准备和使用自回归缓存。

类型

bool

normalize_qk#

是否应该应用 QK 归一化(arxiv.org/abs/2302.05442)。

类型

bool

qk_attn_weights_einsum_cls#

用于创建计算注意力权重的 einsum 的工厂函数。

类型

collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None

attn_weights_value_einsum_cls#

用于创建计算注意力权重与值的乘积的 einsum 的工厂函数。

类型

collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None

__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#

对输入数据应用多头点积注意力。

将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。

如果 inputs_k 和 inputs_v 都为 None,它们将都复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes..., length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes..., length, features] 的键。如果为 None,inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes..., length, features] 的值。如果为 None,inputs_v 将复制 inputs_k 的值。

  • inputs_kv – 形状为 [batch_sizes..., length, features] 的键/值。如果为 None,inputs_kv 将复制 inputs_q 的值。此参数将很快被弃用。使用 inputs_k 和 inputs_v 代替。

  • mask – 形状为 [batch_sizes..., num_heads, query_length, key/value_length] 的注意力掩码。如果其相应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。

  • dropout_rng – 传递给注意力层 dropout 掩码的可选 rng 密钥。否则,将使用 self.make_rng(‘dropout’) 代替。

  • sow_weights – 如果为 True,则注意力权重会被播种到 ‘intermediates’ 集合中。请记住,通过 mutable=['intermediates'] 将 ‘intermediates’ 标记为可变,以便返回该集合。

返回

形状为 [batch_sizes..., length, features] 的输出。

方法

class flax.linen.MultiHeadAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

多头点积注意力。MultiHeadDotProductAttention 的别名。

注意MultiHeadAttentionMultiHeadDotProductAttention 的包装器,因此它们的实现相同。但是,默认情况下,MultiHeadAttention 层将被命名为 MultiHeadAttention_{index},而 MultiHeadDotProductAttention 将被命名为 MultiHeadDotProductAttention_{index}。因此,这可能会影响模块内的检查点、参数集合名称和 RNG 线程(因为在生成新的 RNG 时会使用层名称)。

示例用法

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
num_heads#

注意头的数量。特征(即 inputs_q.shape[-1])应能被头的数量整除。

类型

int

dtype#

计算的 dtype(默认值:从输入和参数推断)

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

qkv_features#

键、查询和值的维度。

类型

int | None

out_features#

最后投影的维度

类型

int | None

broadcast_dropout#

bool:沿批次维度使用广播 dropout。

类型

bool

dropout_rate#

dropout 率

类型

float

deterministic#

如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。

类型

bool | None

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

Dense 层内核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

Dense 层偏差的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

use_bias#

bool:点态 QKVO 稠密变换是否使用偏差。

类型

bool

attention_fn#

dot_product_attention 或兼容函数。 接受查询、键、值并返回形状为 [bs, dim1, dim2, ..., dimN,, num_heads, value_channels] 的输出

类型

collections.abc.Callable[[…], Union[jax.Array, Any]]

decode#

是否准备和使用自回归缓存。

类型

bool

normalize_qk#

是否应应用 QK 规范化(arxiv.org/abs/2302.05442)。

类型

bool

__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)#

对输入数据应用多头点积注意力。

将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。

如果 inputs_k 和 inputs_v 都为 None,它们将都复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes..., length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes..., length, features] 的键。如果为 None,inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes..., length, features] 的值。如果为 None,inputs_v 将复制 inputs_k 的值。

  • inputs_kv – 形状为 [batch_sizes..., length, features] 的键/值。如果为 None,inputs_kv 将复制 inputs_q 的值。此参数将很快被弃用。使用 inputs_k 和 inputs_v 代替。

  • mask – 形状为 [batch_sizes..., num_heads, query_length, key/value_length] 的注意力掩码。如果其相应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。

  • dropout_rng – 传递给注意力层 dropout 掩码的可选 rng 密钥。否则,将使用 self.make_rng(‘dropout’) 代替。

  • sow_weights – 如果为 True,则注意力权重会被播种到 ‘intermediates’ 集合中。请记住,通过 mutable=['intermediates'] 将 ‘intermediates’ 标记为可变,以便返回该集合。

返回

形状为 [batch_sizes..., length, features] 的输出。

方法

class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

多头点积自注意力的特殊情况。该层已被弃用,建议使用 MultiHeadDotProductAttention

用法示例:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
__call__(inputs_q, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#

在输入数据上应用多头点积自注意力。

将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。

参数
  • inputs_q – 形状为 [batch_sizes..., length, features] 的输入查询。

  • mask – 形状为 [batch_sizes..., num_heads, query_length, key/value_length] 的注意力掩码。如果其相应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。

返回

形状为 [batch_sizes..., length, features] 的输出。

方法

flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, einsum=None)[source]#

计算给定查询和键的点积注意力权重。

用于 dot_product_attention(),这是您最有可能使用的函数。但是,如果您想访问注意力权重以进行自省,那么您可以直接调用此函数并自行调用 einsum。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • bias – 注意力权重的偏差。这应该可以广播到形状 [batch..., num_heads, q_length, kv_length]。这可用于合并因果掩码、填充掩码、邻近偏差等。

  • mask – 注意力权重的掩码。这应该可以广播到形状 [batch..., num_heads, q_length, kv_length]。这可用于合并因果掩码。如果相应的掩码值为 False,则注意力权重会被屏蔽。

  • broadcast_dropout – bool: 在批次维度上使用广播的丢弃。

  • dropout_rng – JAX PRNGKey: 用于丢弃

  • dropout_rate – 丢弃率

  • deterministic – bool,确定性还是非确定性(是否应用丢弃)

  • dtype – 计算的数据类型(默认值:从输入和参数推断)

  • precision – 计算的数值精度,请参阅 jax.lax.Precision 了解详细信息。

  • module – 将注意力权重播种到 ‘intermediates’ 集合中的模块。请记住通过 mutable=['intermediates'] 将 ‘intermediates’ 标记为可变,以便返回该集合。如果 module 为 None,则不会播种注意力权重。

  • force_fp32_for_softmax – bool,是否强制以 fp32 计算 softmax。这对混合精度训练很有用,在混合精度训练中,需要更高的精度才能确保数值稳定性。

  • einsum_dot_general – 在 einsum 中使用的 dot_general。

  • einsum – 如果未指定,将使用默认的 jnp.einsum。此参数与 precisioneinsum_dot_general 互斥。

引发

ValueError – 如果同时指定了 precision/einsum_dot_generaleinsum

返回

形状为 [batch..., num_heads, q_length, kv_length] 的输出。

flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, qk_attn_weights_einsum=None, attn_weights_value_einsum=None)[source]#

计算给定查询、键和值的点积注意力。

这是基于 https://arxiv.org/abs/1706.03762 应用注意力的核心函数。它计算给定查询和键的注意力权重,并使用注意力权重组合值。

注意

querykeyvalue 不需要任何批次维度。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • value – 用于注意力的值,形状为 [batch..., kv_length, num_heads, v_depth_per_head]

  • bias – 注意力权重的偏差。这应该可以广播到形状 [batch..., num_heads, q_length, kv_length]。这可用于合并因果掩码、填充掩码、邻近偏差等。

  • mask – 注意力权重的掩码。这应该可以广播到形状 [batch..., num_heads, q_length, kv_length]。这可用于合并因果掩码。如果相应的掩码值为 False,则注意力权重会被屏蔽。

  • broadcast_dropout – bool: 在批次维度上使用广播的丢弃。

  • dropout_rng – JAX PRNGKey: 用于丢弃

  • dropout_rate – 丢弃率

  • deterministic – bool,确定性还是非确定性(是否应用丢弃)

  • dtype – 计算的数据类型(默认值:从输入推断)

  • precision – 计算的数值精度,请参阅 ``jax.lax.Precision` 了解详细信息。

  • module – 将注意力权重播种到 ‘intermediates’ 集合中的模块。请记住通过 mutable=['intermediates'] 将 ‘intermediates’ 标记为可变,以便返回该集合。如果 module 为 None,则不会播种注意力权重。

  • force_fp32_for_softmax – bool,是否强制以 fp32 计算 softmax。这对混合精度训练很有用,在混合精度训练中,需要更高的精度才能确保数值稳定性。

  • einsum_dot_general – 在 jnp.einsum 中使用的 dot_general。

  • qk_attn_weights_einsum – 用于计算注意力权重的 einsum。未指定时,将使用默认的 jnp.einsum。此参数与 precisioneinsum_dot_general 互斥。

  • attn_weights_value_einsum – 用于计算注意力权重和值的乘积的 einsum。未指定时,将使用默认的 jnp.einsum。此参数与 precisioneinsum_dot_general 互斥。

返回

输出形状为 [batch..., q_length, num_heads, v_depth_per_head]

引发
  • ValueError – 如果同时指定了 precision/einsum_dot_general

  • qk_attn_weights_einsum

flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

用于注意力权重的掩码生成助手。

对于一维输入(即,[batch..., len_q][batch..., len_kv]),注意力权重将是 [batch..., heads, len_q, len_kv],此函数将生成 [batch..., 1, len_q, len_kv]

参数
  • query_input – 查询长度大小的批处理扁平输入

  • key_input – 密钥长度大小的批处理扁平输入

  • pairwise_fn – 广播逐元素比较函数

  • extra_batch_dims – 要添加单例轴的额外批次维数,默认情况下为零

  • dtype – 掩码返回的 dtype

返回

用于一维注意力的 [batch..., 1, len_q, len_kv] 形状的掩码。

flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

为自注意力生成因果掩码。

对于一维输入(即,[batch..., len]),自注意力权重将是 [batch..., heads, len, len],此函数将生成形状为 [batch..., 1, len, len] 的因果掩码。

参数
  • x – 形状为 [batch..., len] 的输入数组

  • extra_batch_dims – 要添加单例轴的批次维数,默认情况下为零

  • dtype – 掩码返回的 dtype

返回

用于一维注意力的 [batch..., 1, len, len] 形状的因果掩码。

循环#

class flax.linen.RNNCellBase(parent=<flax.linen.module._Sentinel object>, name=None)[source]#

RNN 单元基类。

__call__(**kwargs)#

将 self 作为函数调用。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.LSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

LSTM 单元。

该单元的数学定义如下

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一时间步的输出,c 是内存。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.LSTMCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征数。

类型

int

gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的内核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差参数的初始化器(默认:initializers.zeros_init())

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[source]#

长短期记忆(LSTM)单元。

参数
  • carry – LSTM 单元的隐藏状态,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。

返回

包含新载体和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.OptimizedLSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

更高效的 LSTM 单元,它在 matmul 之前连接状态组件。

参数与 LSTMCell 兼容。请注意,只要隐藏大小大致 <= 2048 个单元,此单元通常比 LSTMCell 更快。

该单元的数学定义与 LSTMCell 相同,如下所示

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一时间步的输出,c 是内存。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.OptimizedLSTMCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的内核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差参数的初始化器(默认:initializers.zeros_init())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的 dtype(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[source]#

一个优化的长短期记忆(LSTM)单元。

参数
  • carry – LSTM 单元的隐藏状态,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。

返回

包含新载体和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.ConvLSTMCell(features, kernel_size, strides=None, padding='SAME', use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

一个卷积 LSTM 单元。

该实现基于 xingjian2015convolutional。给定 x_t 和先前状态 (h_{t-1}, c_{t-1}),核心计算

\[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]

其中 * 表示卷积运算;i_t、f_t、o_t 是输入、遗忘和输出门激活,g_t 是单元更新的向量。

注意

遗忘门初始化

遵循 jozefowicz2015empirical,我们在初始化后将 1.0 添加到 b_f 以减少训练开始时遗忘的规模。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (3, 5, 5))
>>> layer = nn.ConvLSTMCell(features=4, kernel_size=(2, 2))
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。

类型

collections.abc.Sequence[int]

strides#

一个 n 个整数的序列,表示窗口间步长。

类型

collections.abc.Sequence[int] | None

padding#

字符串 'SAME'、字符串 'VALID'n(low, high) 整数对的序列,表示要应用于每个空间维度的前后的填充。

类型

str | collections.abc.Sequence[tuple[int, int]]

bias#

是否在输出中添加偏差(默认:True)。

dtype#

计算的 dtype(默认:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[source]#

构造一个卷积 LSTM。

参数
  • carry – Conv2DLSTM 单元的隐藏状态,使用 Conv2DLSTM.initialize_carry 初始化。

  • inputs – 输入数据,维度为 (batch, spatial_dims…, features)。

返回

包含新载体和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.SimpleCell(features, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, residual=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

简单单元。

该单元的数学定义如下

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]

其中 x 是输入,h 是前一时间步的输出。

如果 residualTrue

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.SimpleCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征数。

类型

int

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的内核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差参数的初始化器(默认:initializers.zeros_init())

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的 dtype(默认:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

residual#

预激活残差连接 (https://arxiv.org/abs/1801.06105).

类型

bool

__call__(carry, inputs)[source]#

简单单元。

参数
  • carry – Simple 单元的隐藏状态,使用 SimpleCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。

返回

包含新载体和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.GRUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

GRU 单元。

该单元的数学定义如下

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一时间步的输出。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.GRUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征数。

类型

int

gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的内核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏差参数的初始化器(默认:initializers.zeros_init())

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的 dtype(默认:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[source]#

门控循环单元 (GRU) 单元。

参数
  • carry – GRU 单元的隐藏状态,使用 GRUCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。

返回

包含新载体和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.MGUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, forget_bias_init=<function ones>, activation_bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, reset_gate=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

MGU 单元 (https://arxiv.org/pdf/1603.09420.pdf).

该单元的数学定义如下

\[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + f * (W_{hn} h + b_{hn})) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一时间步的输出。

如果 reset_gate 为假,则上述公式变为

\[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + W_{hn} h) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.MGUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征数。

类型

int

gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的内核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

forget_bias_init#

遗忘门偏差参数的初始化器。默认设置为 initializers.ones_init(),因为这可以防止梯度消失。有关更多详细信息,请参见 https://proceedings.mlr.press/v37/jozefowicz15.pdf,第 2.2 节。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

activation_bias_init#

激活输出的偏差参数的初始化器(默认:initializers.zeros_init())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的 dtype(默认:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

reset_gate#

应用重置门的标志。

类型

bool

__call__(carry, inputs)[source]#

最小门控单元 (MGU) 单元。

参数
  • carry – MGU 单元的隐藏状态,使用 MGUCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。

返回

包含新载体和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元载体。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 提供单元输入形状的元组。

返回

给定 RNN 单元的初始化载体。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元载体。

class flax.linen.RNN(cell, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({     params: False, }), parent=<flax.linen.module._Sentinel object>, name=None)[source]#

The RNN module takes any RNNCellBase instance and applies it over a sequence

using flax.linen.scan().

示例

>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn

>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64))
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

如上所示,RNN 使用 cell_size 参数为单元的 initialize_carry 方法设置 size 参数,实际上,这通常是你想要的单元的隐藏单元数量。但是,这可能根据你使用的单元而有所不同,例如 ConvLSTMCell 需要 size 参数,形式为 (kernel_height, kernel_width, features)

>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features)
>>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)))
>>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x)
>>> y.shape # (batch, time, height, width, features)
(10, 50, 32, 32, 64)

默认情况下,RNN 期望时间维度位于批次维度之后 ((*batch, time, *features)),如果设置 time_major=True,则 RNN 将期望时间维度位于开头 ((time, *batch, *features))

>>> x = jnp.ones((50, 10, 32)) # (time, batch, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (time, batch, cell_size)
(50, 10, 64)

输出是一个形状为 (*batch, time, *cell_size) 的数组(通常),但是如果设置 return_carry=True,则它将返回最终的 carry 和输出的元组。

>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> carry, y = lstm.apply(variables, x)
>>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size))
((10, 64), (10, 64))
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

为了支持可变长度序列,你可以传递一个 seq_lengths,它是一个形状为 (*batch) 的整数数组,其中每个元素都是批次中序列的长度。例如

>>> seq_lengths = jnp.array([3, 2, 5])

对应于填充元素的输出元素不会被清零。如果将 return_carry 设置为 True,则 carry 将是每个序列的最后一个有效元素的状态。

RNN 还接受 flax.linen.scan() 的一些参数,默认情况下它们被设置为与像 LSTMCellGRUCell 这样的单元一起使用,但可以根据需要覆盖它们。覆盖扫描的默认值如下所示

>>> lstm = nn.RNN(
...   nn.LSTMCell(64),
...   unroll=1, variable_axes={}, variable_broadcast='params',
...   variable_carry=False, split_rngs={'params': False})
cell#

一个 RNNCellBase 的实例。

类型

flax.linen.recurrent.RNNCellBase

time_major#

如果 time_major=False(默认)它将期望输入的形状为 (*batch, time, *features),否则它将期望输入的形状为 (time, *batch, *features)

类型

bool

return_carry#

如果 return_carry=False(默认)只返回输出序列,否则将返回最终承载和输出序列的元组。

类型

bool

reverse#

如果 reverse=False(默认)序列从左到右处理并以原始顺序返回,否则它将从右到左处理,并以相反顺序返回。如果传递了 seq_lengths,填充将始终保留在序列的末尾。

类型

bool

keep_order#

如果 keep_order=True,当 reverse=True 时,输出将在处理后被反转回原始顺序,这对于对齐双向 RNN 中的序列很有用。如果 keep_order=False(默认),输出将保留在由 reverse 指定的顺序。

类型

bool

unroll#

在循环的单次迭代中展开多少次扫描迭代,默认为 1。此参数将传递给 nn.scan

类型

int

variable_axes#

一个字典,将每个集合映射到一个整数 i(表示我们扫描维度 i)或 None(复制而不是扫描)。此参数将转发到 nn.scan

类型

collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.typing.In[int], flax.typing.Out[int]]]

variable_broadcast#

指定广播的变量集合。广播变量不应依赖于任何无法从循环中提取的计算。这通常用于在 fn 内定义共享参数。此参数将转发到 nn.scan

类型

Union[bool, str, Collection[str], DenyList]

variable_carry#

指定在循环中传递的变量集合。对这些变量的修改将被传递到下一轮迭代,并且在扫描结束时将被保留。此参数将转发到 nn.scan

类型

Union[bool, str, Collection[str], DenyList]

split_rngs#

一个映射,从 PRNGSequenceFilter 到 bool,指定一个集合的 PRNG 密钥是否应该被拆分,以便其值在每个步骤中都不同,或者被复制,以便其值在每个步骤中都保持相同。此参数将转发到 nn.scan

类型

collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], bool]

__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

将 RNN 应用于输入。

__call__ 允许您选择性地覆盖某些属性,例如 return_carrytime_major,这些属性是在构造函数中定义的。

参数
  • inputs – 输入序列。

  • initial_carry – 初始承载,如果未提供,它将使用单元的 RNNCellBase.initialize_carry() 方法初始化。

  • init_key – 用于初始化承载的 PRNG 密钥,如果未提供,将使用 jax.random.key(0)。大多数单元将忽略此参数。

  • seq_lengths – 一个可选的形状为 (*batch) 的整数数组,指示每个序列的长度,时间维度中索引大于相应长度的元素将被视为填充并被忽略。

  • return_carry – 如果 return_carry=False(默认)只返回输出序列,否则将返回最终承载和输出序列的元组。

  • time_major – 如果 time_major=False(默认)它将期望输入的形状为 (*batch, time, *features),否则它将期望输入的形状为 (time, *batch, *features)

  • reverse – 覆盖 reverse 属性,如果 reverse=False(默认)序列从左到右处理并以原始顺序返回,否则它将从右到左处理,并以相反顺序返回。如果传递了 seq_lengths,填充将始终保留在序列的末尾。

  • keep_order – 覆盖 keep_order 属性,如果 keep_order=True,当 reverse=True 时,输出将在处理后被反转回原始顺序,这对于对齐双向 RNN 中的序列很有用。如果 keep_order=False(默认),输出将保留在由 reverse 指定的顺序。

返回

如果 return_carry=False(默认)只返回输出序列,否则将返回最终承载和输出序列的元组。

方法

class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

双向处理输入并合并结果。

示例用法

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4)))
>>> x = jnp.ones((2, 3))
>>> variables = layer.init(jax.random.key(0), x)
>>> out = layer.apply(variables, x)
__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

将 self 作为函数调用。

方法

BatchApply#

class flax.linen.BatchApply(f, num_dims=2)[source]#

临时合并输入张量的领先维度。

将张量的领先维度合并为单个维度,运行给定的可调用对象,然后拆分结果的领先维度以匹配输入。

排名小于要折叠的维度数量的输入数组将被原样传递。

这对于将模块应用于例如 [Time, Batch, ...] 数组的每个时间步可能很有用。

对于某些 f 和平台,这可能比 jax.vmap() 更有效,尤其是在与其他转换(如 jax.grad())结合使用时。

示例用法

>>> import jax, jax.numpy as jnp

>>> a = jax.random.normal(jax.random.key(0), [2, 3, 4])
>>> b = jax.random.normal(jax.random.key(1), [4])

>>> def raises(a, b):
...   if len(a.shape) != 2:
...     raise ValueError("a must be shape 2")
...   if len(b.shape) != 1:
...     raise ValueError("b must be shape 1")
...   return jnp.dot(a, b)

>>> out = BatchApply(raises)(a, b)
>>> expected_merged_leading = raises(a.reshape(2*3, 4), b)
>>> expected = expected_merged_leading.reshape((2, 3) + expected_merged_leading.shape[1:])
>>> np.testing.assert_array_equal(out, expected)
__call__(*args, **kwargs)[source]#

将 self 作为函数调用。

方法