层#
线性模块#
- class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
应用于输入最后一维度的线性变换。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Dense(features=4) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3))) >>> jax.tree_util.tree_map(jnp.shape, params) {'params': {'bias': (4,), 'kernel': (3, 4)}}
- features#
输出特征的数量。
- 类型
int
- use_bias#
是否在输出中添加偏差(默认:True)。
- 类型
bool
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
权重矩阵的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
方法
- class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
具有灵活轴的线性变换。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # equivalent to `nn.Dense(features=4)` >>> layer = nn.DenseGeneral(features=4) >>> # output features (4, 5) >>> layer = nn.DenseGeneral(features=(4, 5)) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3))) >>> jax.tree_util.tree_map(jnp.shape, params) {'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}} >>> # apply transformation on the the second and last axes >>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1)) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7))) >>> jax.tree_util.tree_map(jnp.shape, params) {'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
- features#
输出特征数量的整型或元组。
- 类型
int | collections.abc.Sequence[int]
- axis#
要应用变换的轴的整型或元组。例如,(-2, -1) 将把变换应用于最后两个轴。
- 类型
int | collections.abc.Sequence[int]
- batch_dims#
包含批次轴的元组。
- 类型
collections.abc.Sequence[int]
- use_bias#
是否在输出中添加偏差(默认:True)。
- 类型
bool
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- kernel_init#
权重矩阵的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
方法
- class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
包装 `lax.conv_general_dilated` 的卷积模块。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # valid padding >>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (3, 3, 4)}} >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}} >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID') >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
卷积滤波器的数量。
- 类型
int
- kernel_size#
卷积核的形状。单个整数将被解释为单个整数的元组。
- 类型
int | collections.abc.Sequence[int]
- strides#
一个整数或一个 `n` 个整数的序列,表示窗口间步长(默认值:1)。
- 类型
None | int | collections.abc.Sequence[int]
- padding#
字符串 `'SAME'`,字符串 `'VALID'`,字符串 `'CIRCULAR'`(周期性边界条件)之一,或者一个 `n` 个 `(low, high)` 整数对的序列,这些整数对给出要对每个空间维度应用的填充。单个 int 被解释为在所有维度上应用相同的填充,并在序列中分配单个 int 会导致在两侧使用相同的填充。对于一维卷积,`'CAUSAL'` 填充将对卷积轴进行左填充,从而产生相同大小的输出。
- 类型
Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]
- input_dilation#
一个整数或一个 `n` 个整数的序列,给出要对 `inputs` 的每个空间维度应用的膨胀因子(默认值:1)。具有输入膨胀 `d` 的卷积等效于具有步长 `d` 的转置卷积。
- 类型
None | int | collections.abc.Sequence[int]
- kernel_dilation#
一个整数或一个 `n` 个整数的序列,给出要对卷积核的每个空间维度应用的膨胀因子(默认值:1)。具有核膨胀的卷积也称为“空洞卷积”。
- 类型
None | int | collections.abc.Sequence[int]
- feature_group_count#
整数,默认值为 1。如果指定,则将输入特征划分为组。
- 类型
int
- use_bias#
是否在输出中添加偏差(默认:True)。
- 类型
bool
- mask#
在掩码卷积期间,权重的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。
- 类型
Optional[Union[jax.Array, Any]]
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- precision#
计算的数值精度,请参阅 ``jax.lax.Precision` 获取详细信息。
- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
卷积核的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- __call__(inputs)#
对输入应用(可能不共享的)卷积。
- 参数
inputs – 具有维度 `(*batch_dims, spatial_dims..., features)` 的输入数据。这是 channels-last 约定,即二维卷积的 NHWC 和三维卷积的 NDHWC。注意:这与 `lax.conv_general_dilated` 使用的输入约定不同,它将空间维度放在最后。注意:如果输入具有多个批次维度,所有批次维度都将被展平成单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层可能比这种默认展平方法产生更好的性能。如果输入缺少批次维度,它将被添加到卷积中,并在返回时删除,这是一个允许编写单示例代码的允许值。
- 返回
卷积后的数据。
方法
- class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
封装了
lax.conv_transpose
的卷积模块。示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # valid padding >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (3, 3, 4)}} >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True) >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}} >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID') >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
卷积滤波器的数量。
- 类型
int
- kernel_size#
卷积核的形状。对于一维卷积,核大小可以作为整数传递,将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。
- 类型
int | collections.abc.Sequence[int]
- strides#
一个整数或一个 n 个整数的序列,表示窗口间步长。
- 类型
collections.abc.Sequence[int] | None
- padding#
字符串 ‘SAME’、字符串 ‘VALID’、字符串 ‘CIRCULAR’(周期性边界条件)或 n 个 (low, high) 整数对的序列,这些对给出要应用于每个空间维度的填充。单个 int 被解释为在所有维度中应用相同的填充,并在序列中分配单个 int 导致在两侧使用相同的填充。
- 类型
Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]
- kernel_dilation#
None
或一个整数或n
个整数的序列,给出要应用于卷积核每个空间维度的膨胀因子。具有内核膨胀的卷积也被称为“空洞卷积”。- 类型
collections.abc.Sequence[int] | None
- use_bias#
是否在输出中添加偏差(默认:True)。
- 类型
bool
- mask#
在掩码卷积期间,权重的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。
- 类型
Optional[Union[jax.Array, Any]]
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
卷积核的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- transpose_kernel#
如果
True
翻转空间轴并交换内核的输入/输出通道轴。- 类型
bool
- __call__(inputs)[source]#
对输入应用转置卷积。
行为镜像
jax.lax.conv_transpose
。- 参数
inputs – 输入数据,维度为
(*batch_dims, spatial_dims..., features).
这是通道最后约定,即二维卷积的 NHWC 和三维卷积的 NDHWC。注意:这与lax.conv_general_dilated
使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度将被展平为单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接对层进行 vmap 可能比这种默认展平方法产生更好的性能。如果输入缺少批次维度,则将为卷积添加它,并在返回时删除,允许编写单个示例代码。- 返回
卷积后的数据。
方法
- class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
封装了
lax.conv_general_dilated_local
的局部卷积模块。示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # valid padding >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}} >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}} >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((6, 9, 4))) >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID') >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
卷积滤波器的数量。
- 类型
int
- kernel_size#
卷积核的形状。单个整数将被解释为单个整数的元组。
- 类型
int | collections.abc.Sequence[int]
- strides#
一个整数或一个 `n` 个整数的序列,表示窗口间步长(默认值:1)。
- 类型
None | int | collections.abc.Sequence[int]
- padding#
字符串 `'SAME'`,字符串 `'VALID'`,字符串 `'CIRCULAR'`(周期性边界条件)之一,或者一个 `n` 个 `(low, high)` 整数对的序列,这些整数对给出要对每个空间维度应用的填充。单个 int 被解释为在所有维度上应用相同的填充,并在序列中分配单个 int 会导致在两侧使用相同的填充。对于一维卷积,`'CAUSAL'` 填充将对卷积轴进行左填充,从而产生相同大小的输出。
- 类型
Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]
- input_dilation#
一个整数或一个 `n` 个整数的序列,给出要对 `inputs` 的每个空间维度应用的膨胀因子(默认值:1)。具有输入膨胀 `d` 的卷积等效于具有步长 `d` 的转置卷积。
- 类型
None | int | collections.abc.Sequence[int]
- kernel_dilation#
一个整数或一个 `n` 个整数的序列,给出要对卷积核的每个空间维度应用的膨胀因子(默认值:1)。具有核膨胀的卷积也称为“空洞卷积”。
- 类型
None | int | collections.abc.Sequence[int]
- feature_group_count#
整数,默认值为 1。如果指定,则将输入特征划分为组。
- 类型
int
- use_bias#
是否在输出中添加偏差(默认:True)。
- 类型
bool
- mask#
在掩码卷积期间,权重的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。
- 类型
Optional[Union[jax.Array, Any]]
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
卷积核的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- __call__(inputs)#
对输入应用(可能不共享的)卷积。
- 参数
inputs – 具有维度 `(*batch_dims, spatial_dims..., features)` 的输入数据。这是 channels-last 约定,即二维卷积的 NHWC 和三维卷积的 NDHWC。注意:这与 `lax.conv_general_dilated` 使用的输入约定不同,它将空间维度放在最后。注意:如果输入具有多个批次维度,所有批次维度都将被展平成单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层可能比这种默认展平方法产生更好的性能。如果输入缺少批次维度,它将被添加到卷积中,并在返回时删除,这是一个允许编写单示例代码的允许值。
- 返回
卷积后的数据。
方法
- class flax.linen.Einsum(shape, einsum_str=None, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
具有可学习内核和偏差的 einsum 变换。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Einsum((5, 6, 7), 'abc,cde->abde') >>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}}
- shape#
内核的形状。
- 类型
collections.abc.Sequence[int]
- einsum_str#
用于表示 einsum 方程的字符串。该方程必须恰好有两个操作数,lhs 是传递的输入,rhs 是可学习的内核。构造函数参数和调用参数中的
einsum_str
必须有一个不为 None,而另一个必须为 None。- 类型
str | None
- use_bias#
是否在输出中添加偏差(默认:True)。
- 类型
bool
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
权重矩阵的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- __call__(inputs, einsum_str=None)[source]#
对输入沿着最后一个维度应用线性变换。
- 参数
inputs – 要变换的 nd 数组。
einsum_str – 表示 einsum 方程式的字符串。该方程式必须正好有两个操作数,左侧是传入的输入,右侧是可学习的核。调用方法中传入的
einsum_str
将优先于构造函数中传入的einsum_str
。
- 返回
变换后的输入。
方法
- class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
嵌入模块。
从整数 [0,
num_embeddings
) 到features
维向量的一个参数化函数。此Module
将创建一个形状为(num_embeddings, features)
的embedding
矩阵。当调用此层时,输入值将用于 0 索引到embedding
矩阵中。对大于或等于num_embeddings
的值的索引将导致nan
值。当num_embeddings
等于 1 时,它将广播embedding
矩阵到输入形状,并附加features
维度。示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Embed(num_embeddings=5, features=3) >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> variables = layer.init(jax.random.key(0), indices_input) >>> variables {'params': {'embedding': Array([[-0.28884724, 0.19018005, -0.414205 ], [-0.11768015, -0.54618824, -0.3789283 ], [ 0.30428642, 0.49511626, 0.01706631], [-0.0982546 , -0.43055868, 0.20654906], [-0.688412 , -0.46882293, 0.26723292]], dtype=float32)}} >>> # get the first three and last three embeddings >>> layer.apply(variables, indices_input) Array([[[-0.28884724, 0.19018005, -0.414205 ], [-0.11768015, -0.54618824, -0.3789283 ], [ 0.30428642, 0.49511626, 0.01706631]], [[-0.688412 , -0.46882293, 0.26723292], [-0.0982546 , -0.43055868, 0.20654906], [ 0.30428642, 0.49511626, 0.01706631]]], dtype=float32)
- num_embeddings#
嵌入数量 / 词汇量大小。
- 类型
int
- features#
每个嵌入的特征维数。
- 类型
int
- dtype#
嵌入向量的 dtype(默认值:与嵌入相同)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- embedding_init#
嵌入初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- __call__(inputs)[source]#
沿最后一个维度嵌入输入。
- 参数
inputs – 输入数据,所有维度都被视为批处理维度。输入数组中的值必须是整数。
- 返回
输出是嵌入的输入数据。输出形状遵循输入,并在其后附加一个
features
维度。
- attend(query)[source]#
使用查询数组对嵌入进行关注。
- 参数
query – 数组,其最后一个维度等于嵌入的特征深度
features
。- 返回
一个数组,其最后一个维度为
num_embeddings
,对应于查询向量数组与每个嵌入的批处理内积。通常用于 NLP 模型中嵌入和 logits 变换之间的权重共享。
方法
attend
(query)使用查询数组对嵌入进行关注。
池化#
- flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#
通过对窗口切片取最大值来池化输入。
- 参数
inputs – 输入数据,维度为 (batch, window dims…, features)。
window_shape – 一个形状元组,定义要减少的窗口。
strides – 一系列
n
个整数,表示窗口间步长(默认值:(1, ..., 1)
)。padding – 既可以是字符串
'SAME'
,也可以是字符串'VALID'
,或者是一系列n
个(low, high)
整数对,用于给出要应用于每个空间维度的前后的填充(默认值:'VALID'
)。
- 返回
每个窗口切片的最大值。
- flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#
通过对窗口取平均值来池化输入。
- 参数
inputs – 输入数据,维度为 (batch, window dims…, features)。
window_shape – 一个形状元组,定义要减少的窗口。
strides – 一系列
n
个整数,表示窗口间步长(默认值:(1, ..., 1)
)。padding – 既可以是字符串
'SAME'
,也可以是字符串'VALID'
,或者是一系列n
个(low, high)
整数对,用于给出要应用于每个空间维度的前后的填充(默认值:'VALID'
)。count_include_pad – 一个布尔值,表示是否将填充的标记包含在平均值计算中(默认值:
True
)。
- 返回
每个窗口切片的平均值。
- flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[source]#
用于定义池化函数的辅助函数。
池化函数是使用 ReduceWindow XLA 操作实现的。
注意
请注意,池化通常不可微。这意味着提供一个可微的 reduce_fn 并不意味着 pool 可微。
- 参数
inputs – 输入数据,维度为 (batch, window dims…, features)。
init – 减少的初始值
reduce_fn – 形式为
(T, T) -> T
的减少函数。window_shape – 一个形状元组,定义要减少的窗口。
strides – 一系列
n
个整数,表示窗口间步长(默认值:(1, ..., 1)
)。padding – 既可以是字符串
'SAME'
,也可以是字符串'VALID'
,或者是一系列n
个(low, high)
整数对,用于给出要应用于每个空间维度的前后的填充。
- 返回
每个窗口切片的减少输出。
规范化#
- class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
BatchNorm 模块。
使用说明:如果我们定义一个带 BatchNorm 的模型,例如
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> BN = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=jnp.float32)
初始化的变量字典除了包含 ‘params’ 集合外,还会包含一个单独的 ‘batch_stats’ 集合,该集合将包含模型中所有 BatchNorm 层的所有运行统计数据
>>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> variables = BN.init(jax.random.key(1), x, use_running_average=False) >>> jax.tree_util.tree_map(jnp.shape, variables) {'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}}
然后,我们在训练期间通过指定
batch_stats
集合在模块的apply
方法中是可变的来更新 batch_stats。>>> y, new_batch_stats = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=False)
在评估过程中,我们将使用
use_running_average=True
定义 BN,并使用来自训练的 batch_stats 集合来设置统计信息。在这种情况下,我们不会修改 batch statistics 集合,因此不需要将其标记为可变的。>>> y = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=True)
- use_running_average#
如果为 True,则将使用 batch_stats 中存储的统计信息,而不是计算输入的批次统计信息。
- 类型
bool | None
- axis#
输入的特征或非批次轴。
- 类型
int
- momentum#
批次统计信息的指数移动平均的衰减率。
- 类型
float
- epsilon#
添加到方差中的一个小的浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- use_bias#
如果为 True,则添加偏差(beta)。
- 类型
bool
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。
- 类型
bool
- bias_init#
偏差的初始化器,默认值为零。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- scale_init#
比例的初始化器,默认值为一。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- axis_name#
用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅
jax.pmap
(默认值:None)。请注意,这仅用于 pmap 和 shard map。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。- 类型
str | None
- axis_index_groups#
在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅jax.lax.psum
。- 类型
Any
- use_fast_variance#
如果为 True,则使用更快但数值稳定性较差的方差计算方法。
- 类型
bool
- __call__(x, use_running_average=None, *, mask=None)[source]#
使用批次统计信息对输入进行归一化。
注意
在初始化期间(当
self.is_initializing()
为True
时),批次统计信息的运行平均值不会更新。因此,在初始化期间馈送的输入不需要与实际输入分布相匹配,并且约简轴(使用axis_name
设置)不需要存在。- 参数
x – 要归一化的输入。
use_running_average – 如果为 true,则将使用 batch_stats 中存储的统计信息,而不是计算输入的批次统计信息。
mask – 形状可广播到
inputs
张量的二元数组,指示应计算均值和方差的位置。
- 返回
归一化的输入(与输入相同的形状)。
方法
- class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
层归一化 (https://arxiv.org/abs/1607.06450).
LayerNorm 对批次中每个给定示例的层激活进行归一化,而不是像批次归一化一样跨批次进行归一化。即应用变换,使每个示例内的平均激活保持接近 0,并且激活标准差保持接近 1。
注意
此归一化操作与 InstanceNorm 和 GroupNorm 相同;区别仅在于哪些轴被约简以及特征轴的形状(即可学习比例和偏差参数的形状)。
示例用法
>>> import flax.linen as nn >>> import jax >>> import numpy as np >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nn.LayerNorm() >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}} >>> y = layer.apply(variables, x) >>> y = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x) >>> y2 = nn.GroupNorm(num_groups=1).apply(variables, x) >>> np.testing.assert_allclose(y, y2) >>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x) >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x) >>> np.testing.assert_allclose(y, y2)
- epsilon#
添加到方差中的一个小的浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- use_bias#
如果为 True,则添加偏差(beta)。
- 类型
bool
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。
- 类型
bool
- bias_init#
偏差的初始化器,默认值为零。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- scale_init#
比例的初始化器,默认值为一。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- reduction_axes#
用于计算归一化统计信息的轴。
- 类型
Union[int, collections.abc.Sequence[int]]
- feature_axes#
用于学习偏差和缩放的特征轴。
- 类型
Union[int, collections.abc.Sequence[int]]
- axis_name#
用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅
jax.pmap
(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。- 类型
str | None
- axis_index_groups#
在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅jax.lax.psum
。- 类型
Any
- use_fast_variance#
如果为 True,则使用更快但数值稳定性较差的方差计算方法。
- 类型
bool
- __call__(x, *, mask=None)[source]#
对输入应用层归一化。
- 参数
x – 输入
mask – 形状可广播到
inputs
张量的二元数组,指示应计算均值和方差的位置。
- 返回
归一化的输入(与输入相同的形状)。
方法
- class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
组归一化 (arxiv.org/abs/1803.08494).
该操作类似于批归一化,但统计信息在大小相等的通道组之间共享,而不是在批次维度之间共享。因此,组归一化不依赖于批次组成,也不需要维护内部状态来存储统计信息。用户应该指定通道组的总数或每个组的通道数。
注意
LayerNorm 是 GroupNorm 的特例,其中
num_groups=1
,而 InstanceNorm 是 GroupNorm 的特例,其中group_size=1
。示例用法
>>> import flax.linen as nn >>> import jax >>> import numpy as np >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nn.GroupNorm(num_groups=3) >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}} >>> y = layer.apply(variables, x) >>> y = nn.GroupNorm(num_groups=1).apply(variables, x) >>> y2 = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x) >>> np.testing.assert_allclose(y, y2) >>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x) >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x) >>> np.testing.assert_allclose(y, y2)
- num_groups#
通道组的总数。原始组归一化论文建议默认值为 32。
- 类型
int | None
- group_size#
每个组的通道数。
- 类型
int | None
- epsilon#
添加到方差中的一个小的浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- use_bias#
如果为 True,则添加偏差(beta)。
- 类型
bool
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。
- 类型
bool
- bias_init#
偏差的初始化器,默认值为零。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- scale_init#
比例的初始化器,默认值为一。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- reduction_axes#
用于计算归一化统计信息的轴列表。此列表必须包含最后一个维度,该维度被假定为特征轴。此外,如果调用时使用的输入与用于初始化的数据相比具有额外的领先轴,例如由于批处理,则需要明确定义归约轴。
- 类型
Optional[Union[int, collections.abc.Sequence[int]]]
- axis_name#
用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅
jax.pmap
(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。- 类型
str | None
- axis_index_groups#
在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅jax.lax.psum
。- 类型
Any
- use_fast_variance#
如果为 True,则使用更快但数值稳定性较差的方差计算方法。
- 类型
bool
- __call__(x, *, mask=None)[source]#
将组归一化应用于输入 (arxiv.org/abs/1803.08494).
- 参数
x – 形状为
...C
的输入,其中C
是通道维度,而...
表示任意数量的额外维度,这些维度可用于在统计信息上累积。如果没有指定归约轴,则除了假定代表批次的第一个维度之外,所有额外维度...
将用于累积统计信息。mask – 形状可广播到
inputs
张量的二元数组,指示应计算均值和方差的位置。
- 返回
归一化的输入(与输入相同的形状)。
方法
- class flax.linen.RMSNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
RMS 层归一化 (https://arxiv.org/abs/1910.07467).
RMSNorm 独立地为批次中的每个示例规范化层的激活,而不是像批次归一化一样跨批次规范化。与 LayerNorm 将平均值重新居中为 0 并通过激活的标准差进行规范化不同,RMSNorm 根本不重新居中,而是通过激活的均方根进行规范化。
示例用法
>>> import flax.linen as nn >>> import jax >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nn.RMSNorm() >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}} >>> y = layer.apply(variables, x)
- epsilon#
添加到方差中的一个小的浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。
- 类型
bool
- scale_init#
比例的初始化器,默认值为一。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- reduction_axes#
用于计算归一化统计信息的轴。
- 类型
Union[int, collections.abc.Sequence[int]]
- feature_axes#
用于学习偏差和缩放的特征轴。
- 类型
Union[int, collections.abc.Sequence[int]]
- axis_name#
用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅
jax.pmap
(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。- 类型
str | None
- axis_index_groups#
在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅jax.lax.psum
。- 类型
Any
- use_fast_variance#
如果为 True,则使用更快但数值稳定性较差的方差计算方法。
- 类型
bool
- __call__(x, *, mask=None)[source]#
对输入应用 RMS 层归一化。
- 参数
x – 输入
mask – 形状可广播到
inputs
张量的二元数组,指示应计算均值和方差的位置。
- 返回
归一化的输入(与输入相同的形状)。
方法
- class flax.linen.InstanceNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
实例归一化 (https://arxiv.org/abs/1607.08022v3).
InstanceNorm 为每个通道(而不是像 Layer Normalization 那样跨所有通道)以及独立地为批次中的每个示例(而不是像批次归一化那样跨整个批次)规范化层的激活。即,应用一个变换,使每个示例内每个通道内的平均激活值接近 0,而激活标准差接近 1。
注意
此归一化操作与 LayerNorm 和 GroupNorm 相同;唯一的区别在于归约哪些轴以及特征轴的形状(即可学习缩放和偏差参数的形状)。
示例用法
>>> import flax.linen as nn >>> import jax >>> import numpy as np >>> # dimensions: (batch, height, width, channel) >>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) >>> layer = nn.InstanceNorm() >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}} >>> y = layer.apply(variables, x) >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch, >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm >>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x) >>> np.testing.assert_allclose(y, y2, atol=1e-7) >>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x) >>> np.testing.assert_allclose(y, y3, atol=1e-7)
- epsilon#
添加到方差中的一个小的浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- use_bias#
如果为 True,则添加偏差(beta)。
- 类型
bool
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用它,因为缩放将由下一层完成。
- 类型
bool
- bias_init#
偏差的初始化器,默认值为零。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- scale_init#
比例的初始化器,默认值为一。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- feature_axes#
特征轴。学习到的偏置和缩放参数将具有特征轴定义的形状。除了批处理轴(假设为前导轴)之外的所有其他轴都将被减少。
- 类型
Union[int, collections.abc.Sequence[int]]
- axis_name#
用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参阅
jax.pmap
(默认值:None)。这仅在模型在设备之间进行细分时才需要,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您不需要手动同步。只需确保正确注释轴,XLA:SPMD 将插入必要的集体。- 类型
str | None
- axis_index_groups#
在该命名轴内表示要减少的设备子集的轴索引组(默认值:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和后两个设备上的示例进行批次归一化。有关详细信息,请参阅jax.lax.psum
。- 类型
Any
- use_fast_variance#
如果为 True,则使用更快但数值稳定性较差的方差计算方法。
- 类型
bool
- __call__(x, *, mask=None)[source]#
对输入应用实例归一化。
- 参数
x – 输入
mask – 形状可广播到
inputs
张量的二元数组,指示应计算均值和方差的位置。
- 返回
归一化的输入(与输入相同的形状)。
方法
- class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
谱归一化。
见
谱归一化将权重参数归一化,以便矩阵的谱范数等于 1。这实现为一个层包装器,其中每个包装层在其计算其
__call__
输出之前会对其参数进行谱归一化。注意
初始化的变量字典除了包含一个“params”集合之外,还包含一个单独的“batch_stats”集合,该集合将包含一个
u
向量和sigma
值,这些值是执行谱归一化时使用的中间值。在训练期间,我们传入update_stats=True
和mutable=['batch_stats']
,以便使用幂迭代方法计算的最新值更新u
和sigma
。这将有助于幂迭代方法随着时间的推移更准确地逼近真实的奇异值。在评估期间,我们传入update_stats=False
以确保我们从模型中获得确定性行为。示例用法
>>> import flax, flax.linen as nn >>> import jax, jax.numpy as jnp >>> import optax >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(3)(x) ... # only spectral normalize the params of the second Dense layer ... x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train) ... x = nn.Dense(5)(x) ... return x >>> # init >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 5)) >>> model = Foo() >>> variables = model.init(jax.random.PRNGKey(0), x, train=False) >>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables)) FrozenDict({ batch_stats: { SpectralNorm_0: { Dense_1/kernel/sigma: (), Dense_1/kernel/u: (1, 4), }, }, params: { Dense_0: { bias: (3,), kernel: (2, 3), }, Dense_1: { bias: (4,), kernel: (3, 4), }, Dense_2: { bias: (5,), kernel: (4, 5), }, }, }) >>> # train >>> def train_step(variables, x, y): ... def loss_fn(params): ... logits, updates = model.apply( ... {'params': params, 'batch_stats': variables['batch_stats']}, ... x, ... train=True, ... mutable=['batch_stats'], ... ) ... loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y)) ... return loss, updates ... ... (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)( ... variables['params'] ... ) ... return { ... 'params': jax.tree_util.tree_map( ... lambda p, g: p - 0.1 * g, variables['params'], grads ... ), ... 'batch_stats': updates['batch_stats'], ... }, loss >>> for _ in range(10): ... variables, loss = train_step(variables, x, y) >>> # inference / eval >>> out = model.apply(variables, x, train=False)
- layer_instance#
用 SpectralNorm 包装的模块实例
- n_steps#
执行多少步幂迭代以逼近权重参数的奇异值。
- 类型
int
- epsilon#
添加到 l2 规范化中的一个小浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- error_on_non_matrix#
谱归一化仅针对矩阵定义。默认情况下,此模块将返回未改变的标量并将更高阶张量在其前导维度中展平。将此标志设置为 True 将在层使用维度大于 2 的权重张量时抛出错误。
- 类型
bool
- collection_name#
存储执行谱归一化时使用的中间值的集合的名称。
- 类型
str
- __call__(*args, update_stats, **kwargs)[source]#
使用幂迭代方法计算
self.layer_instance
中权重的最大奇异值,并使用此值归一化权重,然后计算__call__
输出。- 参数
*args – 传递给
self.layer_instance
中底层层实例的调用方法的位置参数。update_stats – 如果为 True,则在使用幂迭代方法计算其更新的值后更新内部
u
向量和sigma
值。这将有助于幂迭代方法随着时间的推移更准确地逼近真实的奇异值。**kwargs – 传递给
self.layer_instance
中底层层实例的调用方法的关键字参数。
- 返回
使用谱归一化权重的层的输出。
方法
- class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
L2 权重规范化 (https://arxiv.org/abs/1602.07868).
权重规范化将权重参数归一化,以便矩阵的 l2 范数等于 1。这实现为一个层包装器,其中每个包装层在其计算其
__call__
输出之前会对其参数进行 l2 归一化。示例用法
>>> import flax, flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Baz(nn.Module): ... @nn.compact ... def __call__(self, x): ... return nn.Dense(2)(x) >>> class Bar(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = Baz()(x) ... x = nn.Dense(3)(x) ... x = Baz()(x) ... x = nn.Dense(3)(x) ... return x >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... # l2-normalize all params of the second Dense layer ... x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x) ... x = nn.Dense(5)(x) ... # l2-normalize all kernels in the Bar submodule and all params in ... # the Baz submodule ... x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x) ... return x >>> # init >>> x = jnp.ones((1, 2)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables)) FrozenDict({ params: { Bar_0: { Baz_0: { Dense_0: { bias: (2,), kernel: (5, 2), }, }, Baz_1: { Dense_0: { bias: (2,), kernel: (3, 2), }, }, Dense_0: { bias: (3,), kernel: (2, 3), }, Dense_1: { bias: (3,), kernel: (2, 3), }, }, Dense_0: { bias: (3,), kernel: (2, 3), }, Dense_1: { bias: (4,), kernel: (3, 4), }, Dense_2: { bias: (5,), kernel: (4, 5), }, WeightNorm_0: { Dense_1/bias/scale: (4,), Dense_1/kernel/scale: (4,), }, WeightNorm_1: { Bar_0/Baz_0/Dense_0/bias/scale: (2,), Bar_0/Baz_0/Dense_0/kernel/scale: (2,), Bar_0/Baz_1/Dense_0/bias/scale: (2,), Bar_0/Baz_1/Dense_0/kernel/scale: (2,), Bar_0/Dense_0/kernel/scale: (3,), Bar_0/Dense_1/kernel/scale: (3,), }, }, })
- layer_instance#
用 WeightNorm 包装的模块实例
- epsilon#
添加到 l2 规范化中的一个小浮点数,以避免除以零。
- 类型
float
- dtype#
结果的 dtype(默认值:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- use_scale#
如果为 True,则创建一个可学习的变量
scale
,该变量在 l2 规范化后乘以layer_instance
变量。- 类型
bool
- scale_init#
缩放函数的初始化函数。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- feature_axes#
特征轴维度。l2 范数是通过在剩余的(非特征)轴上减少
layer_instance
变量来计算的。因此,为每个指定的特征计算一个单独的 l2 范数值,并学习一个单独的比例因子(如果use_scale=True
)。默认情况下,尾随维度被视为特征轴。- 类型
Optional[Union[int, collections.abc.Sequence[int]]]
- variable_filter#
一个可选的迭代器,它包含字符串项目。WeightNorm 层将选择性地将 l2 规范化应用于
layer_instance
变量,其键路径(由“/”分隔)与variable_filter
相匹配。例如,variable_filter={'kernel'}
将仅将 l2 规范化应用于键路径包含“kernel”的变量。默认情况下,variable_filter={'kernel'}
。- 类型
collections.abc.Iterable | None
- __call__(*args, **kwargs)[source]#
计算
self.layer_instance
中权重的 l2 范数,并使用此值归一化权重,然后计算__call__
输出。- 参数
*args – 传递给
self.layer_instance
中底层层实例的调用方法的位置参数。**kwargs – 传递给
self.layer_instance
中底层层实例的调用方法的关键字参数。
- 返回
使用 l2 归一化权重的层的输出。
方法
组合器#
- class flax.linen.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
应用一串线性模块。
仅用于将可调用对象融合在一起的简单情况,其中特定模块/操作的输入是前一个模块/操作的输出。
模块将按照它们在构造函数中传递的顺序应用。
Sequential 的
__call__
方法接受任何输入并将其转发到它包含的第一个模块。 它将输出按顺序链接到下一个模块的输入,并返回最终模块的输出。示例用法
>>> import flax.linen as nn >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... return nn.Sequential([nn.Dense(4), ... nn.relu, ... nn.Dense(2), ... nn.log_softmax])(x)
由于 Sequential.__call__ 是一个 紧凑 方法,如果需要形状推断,您也可以传递构建模块的内联函数
module = nn.Sequential([ # << more layers lambda x: SomeModule(x.shape[-1])(x), # shape inference # << more layers ])
此组合器还支持返回多个输出的层,如果作为元组或字典返回。 如果层的输出是一个
tuple
,它将在下一层中扩展为*args
,如果它是一个dict
,它将在下一层中扩展为**kwargs
。示例用法
>>> class CrossAttentionBlock(nn.Module): ... num_heads: int = 2 ... qkv_features: int = 16 ... ... @nn.compact ... def __call__(self, query, key_value): ... output = nn.MultiHeadDotProductAttention( ... num_heads=self.num_heads, qkv_features=self.qkv_features)(query, ... key_value) ... output = nn.Dense(self.qkv_features)(output) ... return dict(query=output, key_value=key_value) # also works for tuples >>> from typing import Sequence >>> class CrossAttentionNetwork(nn.Module): ... num_layers: Sequence[int] ... ... @nn.compact ... def __call__(self, x): ... return nn.Sequential([CrossAttentionBlock() for _ in ... range(self.num_layers)])(query, key_value)
- layers#
要按顺序应用的可调用对象的序列。
- 类型
collections.abc.Sequence[collections.abc.Callable[[…], Any]]
- 引发
ValueError – 如果 layers 不是序列。
方法
随机#
- class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
创建一个 dropout 层。
注意
使用
Module.apply()
时,请确保包含一个名为'dropout'
的 RNG 种子。 dropout 对变量初始化并不必要。示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class MLP(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(4)(x) ... x = nn.Dropout(0.5, deterministic=not train)(x) ... return x >>> model = MLP() >>> x = jnp.ones((1, 3)) >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout >>> model.apply(variables, x, train=False) # don't use dropout Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32) >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout Array([[ 0. , -1.1856356, -1.0369378, 0. ]], dtype=float32)
- rate#
dropout 概率。 (_不是_ 保留率!)
- 类型
float
- broadcast_dims#
将共享相同 dropout 掩码的维度
- 类型
collections.abc.Sequence[int]
- deterministic#
如果为 false,则输入将按
1 / (1 - rate)
缩放并掩盖,而如果为 true,则不应用掩码并按原样返回输入。- 类型
bool | None
- rng_collection#
请求 rng 键时要使用的 rng 集合名称。
- 类型
str
- __call__(inputs, deterministic=None, rng=None)[source]#
将随机 dropout 掩码应用于输入。
- 参数
inputs – 应该被随机掩盖的输入。
deterministic – 如果为 false,则输入将按
1 / (1 - rate)
缩放并掩盖,而如果为 true,则不应用掩码并按原样返回输入。rng – 一个可选的 PRNGKey 用作随机键,如果未指定,则将使用
make_rng
和rng_collection
名称生成一个。
- 返回
被掩盖的输入,重新加权以保留均值。
方法
注意力#
- class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
多头点积注意力。
示例用法
>>> import flax.linen as nn >>> import jax >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16) >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6) >>> shape = (4, 3, 2, 5) >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape) >>> variables = layer.init(jax.random.key(0), q) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer.apply(variables, q, k, v) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k) >>> out = layer.apply(variables, q, k) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q) >>> out = layer.apply(variables, q) >>> attention_kwargs = dict( ... num_heads=8, ... qkv_features=16, ... kernel_init=nn.initializers.ones, ... bias_init=nn.initializers.zeros, ... dropout_rate=0.5, ... deterministic=False, ... ) >>> class Module(nn.Module): ... attention_kwargs: dict ... ... @nn.compact ... def __call__(self, x, dropout_rng=None): ... out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... return out1, out2 >>> module = Module(attention_kwargs) >>> variables = module.init({'params': key1, 'dropout': key2}, q) >>> # out1 and out2 are different. >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3}) >>> # out3 and out4 are different. >>> # out1 and out3 are different. out2 and out4 are different. >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4}) >>> # out1 and out2 are the same. >>> out1, out2 = module.apply(variables, q, dropout_rng=key5) >>> # out1 and out2 are the same as out3 and out4. >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply` >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
- num_heads#
注意头的数量。 特征(即 inputs_q.shape[-1])应该可以被头数整除。
- 类型
int
- dtype#
计算的 dtype(默认值:从输入和参数推断)
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认值:float32)
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- qkv_features#
键、查询和值的维度。
- 类型
int | None
- out_features#
最后投影的维度
- 类型
int | None
- broadcast_dropout#
使用沿批次维度的广播 dropout。
- 类型
bool
- dropout_rate#
Dropout 率。
- 类型
float
- deterministic#
如果为 False,则注意力权重将使用 dropout 随机掩盖,而如果为 True,则注意力权重是确定性的。
- 类型
bool | None
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
Dense 层内核的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- out_kernel_init#
输出 Dense 层内核的可选初始化器,如果为 None,则将使用
kernel_init
。- 类型
Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]
- bias_init#
Dense 层偏差的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- out_bias_init#
输出 Dense 层偏差的可选初始化器,如果为 None,则将使用
bias_init
。- 类型
Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]
- use_bias#
点式 QKVO 密集变换是否使用偏差。
- 类型
bool
- attention_fn#
dot_product_attention 或兼容函数。 接受查询、键、值并返回形状为
[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]
的输出- 类型
collections.abc.Callable[[…], Union[jax.Array, Any]]
- decode#
是否准备和使用自回归缓存。
- 类型
bool
- normalize_qk#
是否应该应用 QK 归一化(arxiv.org/abs/2302.05442)。
- 类型
bool
- qk_attn_weights_einsum_cls#
用于创建计算注意力权重的 einsum 的工厂函数。
- 类型
collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None
- attn_weights_value_einsum_cls#
用于创建计算注意力权重与值的乘积的 einsum 的工厂函数。
- 类型
collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#
对输入数据应用多头点积注意力。
将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。
如果 inputs_k 和 inputs_v 都为 None,它们将都复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。
- 参数
inputs_q – 形状为
[batch_sizes..., length, features]
的输入查询。inputs_k – 形状为
[batch_sizes..., length, features]
的键。如果为 None,inputs_k 将复制 inputs_q 的值。inputs_v – 形状为
[batch_sizes..., length, features]
的值。如果为 None,inputs_v 将复制 inputs_k 的值。inputs_kv – 形状为
[batch_sizes..., length, features]
的键/值。如果为 None,inputs_kv 将复制 inputs_q 的值。此参数将很快被弃用。使用 inputs_k 和 inputs_v 代替。mask – 形状为
[batch_sizes..., num_heads, query_length, key/value_length]
的注意力掩码。如果其相应的掩码值为False
,则会屏蔽掉注意力权重。deterministic – 如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。
dropout_rng – 传递给注意力层 dropout 掩码的可选 rng 密钥。否则,将使用 self.make_rng(‘dropout’) 代替。
sow_weights – 如果为
True
,则注意力权重会被播种到 ‘intermediates’ 集合中。请记住,通过mutable=['intermediates']
将 ‘intermediates’ 标记为可变,以便返回该集合。
- 返回
形状为
[batch_sizes..., length, features]
的输出。
方法
- class flax.linen.MultiHeadAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
多头点积注意力。
MultiHeadDotProductAttention
的别名。注意:
MultiHeadAttention
是MultiHeadDotProductAttention
的包装器,因此它们的实现相同。但是,默认情况下,MultiHeadAttention
层将被命名为MultiHeadAttention_{index}
,而MultiHeadDotProductAttention
将被命名为MultiHeadDotProductAttention_{index}
。因此,这可能会影响模块内的检查点、参数集合名称和 RNG 线程(因为在生成新的 RNG 时会使用层名称)。示例用法
>>> import flax.linen as nn >>> import jax >>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16) >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6) >>> shape = (4, 3, 2, 5) >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape) >>> variables = layer.init(jax.random.key(0), q) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer.apply(variables, q, k, v) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k) >>> out = layer.apply(variables, q, k) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q) >>> out = layer.apply(variables, q) >>> attention_kwargs = dict( ... num_heads=8, ... qkv_features=16, ... kernel_init=nn.initializers.ones, ... bias_init=nn.initializers.zeros, ... dropout_rate=0.5, ... deterministic=False, ... ) >>> class Module(nn.Module): ... attention_kwargs: dict ... ... @nn.compact ... def __call__(self, x, dropout_rng=None): ... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... return out1, out2 >>> module = Module(attention_kwargs) >>> variables = module.init({'params': key1, 'dropout': key2}, q) >>> # out1 and out2 are different. >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3}) >>> # out3 and out4 are different. >>> # out1 and out3 are different. out2 and out4 are different. >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4}) >>> # out1 and out2 are the same. >>> out1, out2 = module.apply(variables, q, dropout_rng=key5) >>> # out1 and out2 are the same as out3 and out4. >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply` >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
- num_heads#
注意头的数量。特征(即 inputs_q.shape[-1])应能被头的数量整除。
- 类型
int
- dtype#
计算的 dtype(默认值:从输入和参数推断)
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认值:float32)
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- qkv_features#
键、查询和值的维度。
- 类型
int | None
- out_features#
最后投影的维度
- 类型
int | None
- broadcast_dropout#
bool:沿批次维度使用广播 dropout。
- 类型
bool
- dropout_rate#
dropout 率
- 类型
float
- deterministic#
如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。
- 类型
bool | None
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。- 类型
Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
Dense 层内核的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
Dense 层偏差的初始化器。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- use_bias#
bool:点态 QKVO 稠密变换是否使用偏差。
- 类型
bool
- attention_fn#
dot_product_attention 或兼容函数。 接受查询、键、值并返回形状为
[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]
的输出- 类型
collections.abc.Callable[[…], Union[jax.Array, Any]]
- decode#
是否准备和使用自回归缓存。
- 类型
bool
- normalize_qk#
是否应应用 QK 规范化(arxiv.org/abs/2302.05442)。
- 类型
bool
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)#
对输入数据应用多头点积注意力。
将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。
如果 inputs_k 和 inputs_v 都为 None,它们将都复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。
- 参数
inputs_q – 形状为
[batch_sizes..., length, features]
的输入查询。inputs_k – 形状为
[batch_sizes..., length, features]
的键。如果为 None,inputs_k 将复制 inputs_q 的值。inputs_v – 形状为
[batch_sizes..., length, features]
的值。如果为 None,inputs_v 将复制 inputs_k 的值。inputs_kv – 形状为
[batch_sizes..., length, features]
的键/值。如果为 None,inputs_kv 将复制 inputs_q 的值。此参数将很快被弃用。使用 inputs_k 和 inputs_v 代替。mask – 形状为
[batch_sizes..., num_heads, query_length, key/value_length]
的注意力掩码。如果其相应的掩码值为False
,则会屏蔽掉注意力权重。deterministic – 如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。
dropout_rng – 传递给注意力层 dropout 掩码的可选 rng 密钥。否则,将使用 self.make_rng(‘dropout’) 代替。
sow_weights – 如果为
True
,则注意力权重会被播种到 ‘intermediates’ 集合中。请记住,通过mutable=['intermediates']
将 ‘intermediates’ 标记为可变,以便返回该集合。
- 返回
形状为
[batch_sizes..., length, features]
的输出。
方法
- class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
多头点积自注意力的特殊情况。该层已被弃用,建议使用
MultiHeadDotProductAttention
。- 用法示例:
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16) >>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
- __call__(inputs_q, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#
在输入数据上应用多头点积自注意力。
将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。
- 参数
inputs_q – 形状为
[batch_sizes..., length, features]
的输入查询。mask – 形状为
[batch_sizes..., num_heads, query_length, key/value_length]
的注意力掩码。如果其相应的掩码值为False
,则会屏蔽掉注意力权重。deterministic – 如果为 false,则注意力权重会使用 dropout 随机掩码,而如果为 true,则注意力权重是确定性的。
- 返回
形状为
[batch_sizes..., length, features]
的输出。
方法
- flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, einsum=None)[source]#
计算给定查询和键的点积注意力权重。
用于
dot_product_attention()
,这是您最有可能使用的函数。但是,如果您想访问注意力权重以进行自省,那么您可以直接调用此函数并自行调用 einsum。- 参数
query – 用于计算注意力的查询,形状为
[batch..., q_length, num_heads, qk_depth_per_head]
。key – 用于计算注意力的键,形状为
[batch..., kv_length, num_heads, qk_depth_per_head]
。bias – 注意力权重的偏差。这应该可以广播到形状
[batch..., num_heads, q_length, kv_length]
。这可用于合并因果掩码、填充掩码、邻近偏差等。mask – 注意力权重的掩码。这应该可以广播到形状
[batch..., num_heads, q_length, kv_length]
。这可用于合并因果掩码。如果相应的掩码值为False
,则注意力权重会被屏蔽。broadcast_dropout – bool: 在批次维度上使用广播的丢弃。
dropout_rng – JAX PRNGKey: 用于丢弃
dropout_rate – 丢弃率
deterministic – bool,确定性还是非确定性(是否应用丢弃)
dtype – 计算的数据类型(默认值:从输入和参数推断)
precision – 计算的数值精度,请参阅
jax.lax.Precision
了解详细信息。module – 将注意力权重播种到 ‘intermediates’ 集合中的模块。请记住通过
mutable=['intermediates']
将 ‘intermediates’ 标记为可变,以便返回该集合。如果module
为 None,则不会播种注意力权重。force_fp32_for_softmax – bool,是否强制以 fp32 计算 softmax。这对混合精度训练很有用,在混合精度训练中,需要更高的精度才能确保数值稳定性。
einsum_dot_general – 在 einsum 中使用的 dot_general。
einsum – 如果未指定,将使用默认的 jnp.einsum。此参数与 precision 和 einsum_dot_general 互斥。
- 引发
ValueError – 如果同时指定了 precision/einsum_dot_general 和 einsum。
- 返回
形状为
[batch..., num_heads, q_length, kv_length]
的输出。
- flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, qk_attn_weights_einsum=None, attn_weights_value_einsum=None)[source]#
计算给定查询、键和值的点积注意力。
这是基于 https://arxiv.org/abs/1706.03762 应用注意力的核心函数。它计算给定查询和键的注意力权重,并使用注意力权重组合值。
注意
query
、key
和value
不需要任何批次维度。- 参数
query – 用于计算注意力的查询,形状为
[batch..., q_length, num_heads, qk_depth_per_head]
。key – 用于计算注意力的键,形状为
[batch..., kv_length, num_heads, qk_depth_per_head]
。value – 用于注意力的值,形状为
[batch..., kv_length, num_heads, v_depth_per_head]
。bias – 注意力权重的偏差。这应该可以广播到形状
[batch..., num_heads, q_length, kv_length]
。这可用于合并因果掩码、填充掩码、邻近偏差等。mask – 注意力权重的掩码。这应该可以广播到形状
[batch..., num_heads, q_length, kv_length]
。这可用于合并因果掩码。如果相应的掩码值为False
,则注意力权重会被屏蔽。broadcast_dropout – bool: 在批次维度上使用广播的丢弃。
dropout_rng – JAX PRNGKey: 用于丢弃
dropout_rate – 丢弃率
deterministic – bool,确定性还是非确定性(是否应用丢弃)
dtype – 计算的数据类型(默认值:从输入推断)
precision – 计算的数值精度,请参阅 ``jax.lax.Precision` 了解详细信息。
module – 将注意力权重播种到 ‘intermediates’ 集合中的模块。请记住通过
mutable=['intermediates']
将 ‘intermediates’ 标记为可变,以便返回该集合。如果module
为 None,则不会播种注意力权重。force_fp32_for_softmax – bool,是否强制以 fp32 计算 softmax。这对混合精度训练很有用,在混合精度训练中,需要更高的精度才能确保数值稳定性。
einsum_dot_general – 在 jnp.einsum 中使用的 dot_general。
qk_attn_weights_einsum – 用于计算注意力权重的 einsum。未指定时,将使用默认的 jnp.einsum。此参数与 precision 和 einsum_dot_general 互斥。
attn_weights_value_einsum – 用于计算注意力权重和值的乘积的 einsum。未指定时,将使用默认的 jnp.einsum。此参数与 precision 和 einsum_dot_general 互斥。
- 返回
输出形状为
[batch..., q_length, num_heads, v_depth_per_head]
。- 引发
ValueError – 如果同时指定了 precision/einsum_dot_general 和
qk_attn_weights_einsum。
- flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
用于注意力权重的掩码生成助手。
对于一维输入(即,
[batch..., len_q]
,[batch..., len_kv]
),注意力权重将是[batch..., heads, len_q, len_kv]
,此函数将生成[batch..., 1, len_q, len_kv]
。- 参数
query_input – 查询长度大小的批处理扁平输入
key_input – 密钥长度大小的批处理扁平输入
pairwise_fn – 广播逐元素比较函数
extra_batch_dims – 要添加单例轴的额外批次维数,默认情况下为零
dtype – 掩码返回的 dtype
- 返回
用于一维注意力的
[batch..., 1, len_q, len_kv]
形状的掩码。
- flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
为自注意力生成因果掩码。
对于一维输入(即,
[batch..., len]
),自注意力权重将是[batch..., heads, len, len]
,此函数将生成形状为[batch..., 1, len, len]
的因果掩码。- 参数
x – 形状为
[batch..., len]
的输入数组extra_batch_dims – 要添加单例轴的批次维数,默认情况下为零
dtype – 掩码返回的 dtype
- 返回
用于一维注意力的
[batch..., 1, len, len]
形状的因果掩码。
循环#
- class flax.linen.RNNCellBase(parent=<flax.linen.module._Sentinel object>, name=None)[source]#
RNN 单元基类。
- __call__(**kwargs)#
将 self 作为函数调用。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.LSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
LSTM 单元。
该单元的数学定义如下
\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]其中 x 是输入,h 是前一时间步的输出,c 是内存。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.LSTMCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x)
- features#
输出特征数。
- 类型
int
- gate_fn#
用于门的激活函数(默认:sigmoid)。
- 类型
collections.abc.Callable[[…], Any]
- activation_fn#
用于输出和内存更新的激活函数(默认:tanh)。
- 类型
collections.abc.Callable[[…], Any]
- kernel_init#
用于转换输入的内核的初始化函数(默认:lecun_normal)。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差参数的初始化器(默认:initializers.zeros_init())
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- __call__(carry, inputs)[source]#
长短期记忆(LSTM)单元。
- 参数
carry – LSTM 单元的隐藏状态,使用
LSTMCell.initialize_carry
初始化。inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。
- 返回
包含新载体和输出的元组。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.OptimizedLSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
更高效的 LSTM 单元,它在 matmul 之前连接状态组件。
参数与
LSTMCell
兼容。请注意,只要隐藏大小大致 <= 2048 个单元,此单元通常比LSTMCell
更快。该单元的数学定义与
LSTMCell
相同,如下所示\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]其中 x 是输入,h 是前一时间步的输出,c 是内存。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.OptimizedLSTMCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x)
- gate_fn#
用于门的激活函数(默认:sigmoid)。
- 类型
collections.abc.Callable[[…], Any]
- activation_fn#
用于输出和内存更新的激活函数(默认:tanh)。
- 类型
collections.abc.Callable[[…], Any]
- kernel_init#
用于转换输入的内核的初始化函数(默认:lecun_normal)。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差参数的初始化器(默认:initializers.zeros_init())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- __call__(carry, inputs)[source]#
一个优化的长短期记忆(LSTM)单元。
- 参数
carry – LSTM 单元的隐藏状态,使用
LSTMCell.initialize_carry
初始化。inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。
- 返回
包含新载体和输出的元组。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.ConvLSTMCell(features, kernel_size, strides=None, padding='SAME', use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
一个卷积 LSTM 单元。
该实现基于 xingjian2015convolutional。给定 x_t 和先前状态 (h_{t-1}, c_{t-1}),核心计算
\[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]其中 * 表示卷积运算;i_t、f_t、o_t 是输入、遗忘和输出门激活,g_t 是单元更新的向量。
注意
- 遗忘门初始化
遵循 jozefowicz2015empirical,我们在初始化后将 1.0 添加到 b_f 以减少训练开始时遗忘的规模。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (3, 5, 5)) >>> layer = nn.ConvLSTMCell(features=4, kernel_size=(2, 2)) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x)
- features#
卷积滤波器的数量。
- 类型
int
- kernel_size#
卷积核的形状。
- 类型
collections.abc.Sequence[int]
- strides#
一个
n
个整数的序列,表示窗口间步长。- 类型
collections.abc.Sequence[int] | None
- padding#
字符串
'SAME'
、字符串'VALID'
或n
个(low, high)
整数对的序列,表示要应用于每个空间维度的前后的填充。- 类型
str | collections.abc.Sequence[tuple[int, int]]
- bias#
是否在输出中添加偏差(默认:True)。
- dtype#
计算的 dtype(默认:None)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- __call__(carry, inputs)[source]#
构造一个卷积 LSTM。
- 参数
carry – Conv2DLSTM 单元的隐藏状态,使用
Conv2DLSTM.initialize_carry
初始化。inputs – 输入数据,维度为 (batch, spatial_dims…, features)。
- 返回
包含新载体和输出的元组。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.SimpleCell(features, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, residual=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
简单单元。
该单元的数学定义如下
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]其中 x 是输入,h 是前一时间步的输出。
如果 residual 为 True,
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.SimpleCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x)
- features#
输出特征数。
- 类型
int
- activation_fn#
用于输出和内存更新的激活函数(默认:tanh)。
- 类型
collections.abc.Callable[[…], Any]
- kernel_init#
用于转换输入的内核的初始化函数(默认:lecun_normal)。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差参数的初始化器(默认:initializers.zeros_init())
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- dtype#
计算的 dtype(默认:None)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- residual#
预激活残差连接 (https://arxiv.org/abs/1801.06105).
- 类型
bool
- __call__(carry, inputs)[source]#
简单单元。
- 参数
carry – Simple 单元的隐藏状态,使用
SimpleCell.initialize_carry
初始化。inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。
- 返回
包含新载体和输出的元组。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.GRUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
GRU 单元。
该单元的数学定义如下
\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]其中 x 是输入,h 是前一时间步的输出。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.GRUCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x)
- features#
输出特征数。
- 类型
int
- gate_fn#
用于门的激活函数(默认:sigmoid)。
- 类型
collections.abc.Callable[[…], Any]
- activation_fn#
用于输出和内存更新的激活函数(默认:tanh)。
- 类型
collections.abc.Callable[[…], Any]
- kernel_init#
用于转换输入的内核的初始化函数(默认:lecun_normal)。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- bias_init#
偏差参数的初始化器(默认:initializers.zeros_init())
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- dtype#
计算的 dtype(默认:None)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- __call__(carry, inputs)[source]#
门控循环单元 (GRU) 单元。
- 参数
carry – GRU 单元的隐藏状态,使用
GRUCell.initialize_carry
初始化。inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。
- 返回
包含新载体和输出的元组。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.MGUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, forget_bias_init=<function ones>, activation_bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, reset_gate=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
MGU 单元 (https://arxiv.org/pdf/1603.09420.pdf).
该单元的数学定义如下
\[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + f * (W_{hn} h + b_{hn})) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]其中 x 是输入,h 是前一时间步的输出。
如果
reset_gate
为假,则上述公式变为\[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + W_{hn} h) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.MGUCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x)
- features#
输出特征数。
- 类型
int
- gate_fn#
用于门的激活函数(默认:sigmoid)。
- 类型
collections.abc.Callable[[…], Any]
- activation_fn#
用于输出和内存更新的激活函数(默认:tanh)。
- 类型
collections.abc.Callable[[…], Any]
- kernel_init#
用于转换输入的内核的初始化函数(默认:lecun_normal)。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化函数(默认:initializers.orthogonal())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- forget_bias_init#
遗忘门偏差参数的初始化器。默认设置为 initializers.ones_init(),因为这可以防止梯度消失。有关更多详细信息,请参见 https://proceedings.mlr.press/v37/jozefowicz15.pdf,第 2.2 节。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- activation_bias_init#
激活输出的偏差参数的初始化器(默认:initializers.zeros_init())。
- 类型
Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]
- dtype#
计算的 dtype(默认:None)。
- 类型
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- reset_gate#
应用重置门的标志。
- 类型
bool
- __call__(carry, inputs)[source]#
最小门控单元 (MGU) 单元。
- 参数
carry – MGU 单元的隐藏状态,使用
MGUCell.initialize_carry
初始化。inputs – 一个 ndarray,包含当前时间步的输入。除最后一个以外的所有维度都被视为批处理维度。
- 返回
包含新载体和输出的元组。
- initialize_carry(rng, input_shape)[source]#
初始化 RNN 单元载体。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的初始化载体。
方法
initialize_carry
(rng, input_shape)初始化 RNN 单元载体。
- class flax.linen.RNN(cell, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({ params: False, }), parent=<flax.linen.module._Sentinel object>, name=None)[source]#
The
RNN
module takes anyRNNCellBase
instance and applies it over a sequenceusing
flax.linen.scan()
.示例
>>> import jax.numpy as jnp >>> import jax >>> import flax.linen as nn >>> x = jnp.ones((10, 50, 32)) # (batch, time, features) >>> lstm = nn.RNN(nn.LSTMCell(64)) >>> variables = lstm.init(jax.random.key(0), x) >>> y = lstm.apply(variables, x) >>> y.shape # (batch, time, cell_size) (10, 50, 64)
如上所示,RNN 使用
cell_size
参数为单元的initialize_carry
方法设置size
参数,实际上,这通常是你想要的单元的隐藏单元数量。但是,这可能根据你使用的单元而有所不同,例如ConvLSTMCell
需要size
参数,形式为(kernel_height, kernel_width, features)
>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features) >>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3))) >>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x) >>> y.shape # (batch, time, height, width, features) (10, 50, 32, 32, 64)
默认情况下,RNN 期望时间维度位于批次维度之后 (
(*batch, time, *features)
),如果设置time_major=True
,则 RNN 将期望时间维度位于开头 ((time, *batch, *features)
)>>> x = jnp.ones((50, 10, 32)) # (time, batch, features) >>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True) >>> variables = lstm.init(jax.random.key(0), x) >>> y = lstm.apply(variables, x) >>> y.shape # (time, batch, cell_size) (50, 10, 64)
输出是一个形状为
(*batch, time, *cell_size)
的数组(通常),但是如果设置return_carry=True
,则它将返回最终的 carry 和输出的元组。>>> x = jnp.ones((10, 50, 32)) # (batch, time, features) >>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True) >>> variables = lstm.init(jax.random.key(0), x) >>> carry, y = lstm.apply(variables, x) >>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size)) ((10, 64), (10, 64)) >>> y.shape # (batch, time, cell_size) (10, 50, 64)
为了支持可变长度序列,你可以传递一个
seq_lengths
,它是一个形状为(*batch)
的整数数组,其中每个元素都是批次中序列的长度。例如>>> seq_lengths = jnp.array([3, 2, 5])
对应于填充元素的输出元素不会被清零。如果将
return_carry
设置为True
,则 carry 将是每个序列的最后一个有效元素的状态。RNN 还接受
flax.linen.scan()
的一些参数,默认情况下它们被设置为与像LSTMCell
和GRUCell
这样的单元一起使用,但可以根据需要覆盖它们。覆盖扫描的默认值如下所示>>> lstm = nn.RNN( ... nn.LSTMCell(64), ... unroll=1, variable_axes={}, variable_broadcast='params', ... variable_carry=False, split_rngs={'params': False})
- cell#
一个
RNNCellBase
的实例。
- time_major#
如果
time_major=False
(默认)它将期望输入的形状为(*batch, time, *features)
,否则它将期望输入的形状为(time, *batch, *features)
。- 类型
bool
- return_carry#
如果
return_carry=False
(默认)只返回输出序列,否则将返回最终承载和输出序列的元组。- 类型
bool
- reverse#
如果
reverse=False
(默认)序列从左到右处理并以原始顺序返回,否则它将从右到左处理,并以相反顺序返回。如果传递了seq_lengths
,填充将始终保留在序列的末尾。- 类型
bool
- keep_order#
如果
keep_order=True
,当reverse=True
时,输出将在处理后被反转回原始顺序,这对于对齐双向 RNN 中的序列很有用。如果keep_order=False
(默认),输出将保留在由reverse
指定的顺序。- 类型
bool
- unroll#
在循环的单次迭代中展开多少次扫描迭代,默认为 1。此参数将传递给
nn.scan
。- 类型
int
- variable_axes#
一个字典,将每个集合映射到一个整数
i
(表示我们扫描维度i
)或None
(复制而不是扫描)。此参数将转发到nn.scan
。- 类型
collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.typing.In[int], flax.typing.Out[int]]]
- variable_broadcast#
指定广播的变量集合。广播变量不应依赖于任何无法从循环中提取的计算。这通常用于在 fn 内定义共享参数。此参数将转发到
nn.scan
。- 类型
Union[bool, str, Collection[str], DenyList]
- variable_carry#
指定在循环中传递的变量集合。对这些变量的修改将被传递到下一轮迭代,并且在扫描结束时将被保留。此参数将转发到
nn.scan
。- 类型
Union[bool, str, Collection[str], DenyList]
- split_rngs#
一个映射,从 PRNGSequenceFilter 到 bool,指定一个集合的 PRNG 密钥是否应该被拆分,以便其值在每个步骤中都不同,或者被复制,以便其值在每个步骤中都保持相同。此参数将转发到
nn.scan
。- 类型
collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], bool]
- __call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#
将 RNN 应用于输入。
__call__
允许您选择性地覆盖某些属性,例如return_carry
和time_major
,这些属性是在构造函数中定义的。- 参数
inputs – 输入序列。
initial_carry – 初始承载,如果未提供,它将使用单元的
RNNCellBase.initialize_carry()
方法初始化。init_key – 用于初始化承载的 PRNG 密钥,如果未提供,将使用
jax.random.key(0)
。大多数单元将忽略此参数。seq_lengths – 一个可选的形状为
(*batch)
的整数数组,指示每个序列的长度,时间维度中索引大于相应长度的元素将被视为填充并被忽略。return_carry – 如果
return_carry=False
(默认)只返回输出序列,否则将返回最终承载和输出序列的元组。time_major – 如果
time_major=False
(默认)它将期望输入的形状为(*batch, time, *features)
,否则它将期望输入的形状为(time, *batch, *features)
。reverse – 覆盖
reverse
属性,如果reverse=False
(默认)序列从左到右处理并以原始顺序返回,否则它将从右到左处理,并以相反顺序返回。如果传递了seq_lengths
,填充将始终保留在序列的末尾。keep_order – 覆盖
keep_order
属性,如果keep_order=True
,当reverse=True
时,输出将在处理后被反转回原始顺序,这对于对齐双向 RNN 中的序列很有用。如果keep_order=False
(默认),输出将保留在由reverse
指定的顺序。
- 返回
如果
return_carry=False
(默认)只返回输出序列,否则将返回最终承载和输出序列的元组。
方法
- class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
双向处理输入并合并结果。
示例用法
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4))) >>> x = jnp.ones((2, 3)) >>> variables = layer.init(jax.random.key(0), x) >>> out = layer.apply(variables, x)
- __call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#
将 self 作为函数调用。
方法
BatchApply#
- class flax.linen.BatchApply(f, num_dims=2)[source]#
临时合并输入张量的领先维度。
将张量的领先维度合并为单个维度,运行给定的可调用对象,然后拆分结果的领先维度以匹配输入。
排名小于要折叠的维度数量的输入数组将被原样传递。
这对于将模块应用于例如
[Time, Batch, ...]
数组的每个时间步可能很有用。对于某些
f
和平台,这可能比jax.vmap()
更有效,尤其是在与其他转换(如jax.grad()
)结合使用时。示例用法
>>> import jax, jax.numpy as jnp >>> a = jax.random.normal(jax.random.key(0), [2, 3, 4]) >>> b = jax.random.normal(jax.random.key(1), [4]) >>> def raises(a, b): ... if len(a.shape) != 2: ... raise ValueError("a must be shape 2") ... if len(b.shape) != 1: ... raise ValueError("b must be shape 1") ... return jnp.dot(a, b) >>> out = BatchApply(raises)(a, b) >>> expected_merged_leading = raises(a.reshape(2*3, 4), b) >>> expected = expected_merged_leading.reshape((2, 3) + expected_merged_leading.shape[1:]) >>> np.testing.assert_array_equal(out, expected)
方法