线性#
NNX 线性层类。
- class flax.nnx.Conv(*args, **kwargs)[源代码]#
卷积模块,封装
lax.conv_general_dilated
。用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), ... strides=2, padding='CIRCULAR', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- in_features#
int 或包含输入特征数量的元组。
- out_features#
int 或包含输出特征数量的元组。
- kernel_size#
卷积核的形状。对于 1D 卷积,内核大小可以作为整数传递,这将解释为单个整数的元组。对于所有其他情况,它必须是整数序列。
- strides#
一个整数或一个
n
个整数的序列,表示窗口之间的步长(默认值:1)。
- padding#
字符串
'SAME'
、字符串'VALID'
、字符串'CIRCULAR'
(周期性边界条件)或n
个(low, high)
整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度中应用相同的填充,并且在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL'
填充用于 1D 卷积将左填充卷积轴,从而产生相同大小的输出。
- input_dilation#
一个整数或一个
n
个整数的序列,给出要应用于inputs
的每个空间维度的扩张因子(默认值:1)。具有输入扩张d
的卷积等效于具有步长d
的转置卷积。
- kernel_dilation#
一个整数或一个
n
个整数的序列,给出要应用于卷积核的每个空间维度的扩张因子(默认值:1)。具有核扩张的卷积也称为“空洞卷积”。
- feature_group_count#
整数,默认值为 1。如果指定,则将输入特征分为若干组。
- use_bias#
是否在输出中添加偏置(默认值:True)。
- mask#
掩码卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参见
jax.lax.Precision
。
- kernel_init#
卷积核的初始化器。
- bias_init#
偏置的初始化器。
- rngs#
rng 键。
- __call__(inputs)[源代码]#
将(可能未共享的)卷积应用于输入。
- 参数
inputs – 输入数据,其维度为
(*batch_dims, spatial_dims..., features)
。这是通道最后约定,即 2D 卷积为 NHWC,3D 卷积为 NDHWC。注意:这与lax.conv_general_dilated
使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度都会被展平为单个维度进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 该层可能会比这种默认展平方法产生更好的性能。如果输入缺少批次维度,则会为卷积添加该维度,并在返回时删除该维度,从而允许编写单个示例代码。- 返回
卷积后的数据。
方法
- class flax.nnx.ConvTranspose(*args, **kwargs)[源代码]#
卷积模块,封装
lax.conv_transpose
。用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6), ... strides=(2, 2), padding='CIRCULAR', ... transpose_kernel=True, rngs=rngs) >>> layer.kernel.value.shape (6, 6, 4, 3) >>> layer.bias.value.shape (4,) >>> out = layer(jnp.ones((1, 15, 15, 3))) >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- in_features#
int 或包含输入特征数量的元组。
- out_features#
int 或包含输出特征数量的元组。
- kernel_size#
卷积核的形状。对于 1D 卷积,内核大小可以作为整数传递,这将解释为单个整数的元组。对于所有其他情况,它必须是整数序列。
- strides#
一个整数或一个
n
个整数的序列,表示窗口之间的步长(默认值:1)。
- padding#
字符串
'SAME'
、字符串'VALID'
、字符串'CIRCULAR'
(周期性边界条件)或n
个(low, high)
整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度中应用相同的填充,并且在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL'
填充用于 1D 卷积将左填充卷积轴,从而产生相同大小的输出。
- kernel_dilation#
一个整数或一个
n
个整数的序列,给出要应用于卷积核的每个空间维度的扩张因子(默认值:1)。具有核扩张的卷积也称为“空洞卷积”。
- use_bias#
是否在输出中添加偏置(默认值:True)。
- mask#
掩码卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参见
jax.lax.Precision
。
- kernel_init#
卷积核的初始化器。
- bias_init#
偏置的初始化器。
- transpose_kernel#
如果
True
,则翻转空间轴并交换内核的输入/输出通道轴。
- rngs#
rng 键。
- __call__(inputs)[源代码]#
将转置卷积应用于输入。
行为与
jax.lax.conv_transpose
类似。- 参数
inputs – 输入数据,维度为
(*batch_dims, spatial_dims..., features)
。这是通道最后(channels-last)的约定,例如,对于二维卷积是 NHWC,对于三维卷积是 NDHWC。注意:这与lax.conv_general_dilated
使用的输入约定不同,后者将空间维度放在最后。注意:如果输入有多个批次维度,所有批次维度将展平为一个维度进行卷积,并在返回之前恢复。在某些情况下,直接对层进行 vmap 操作可能比此默认展平方法产生更好的性能。如果输入缺少批次维度,则将添加一个批次维度用于卷积,并在返回时删除,以便编写单样本代码。- 返回
卷积后的数据。
方法
- class flax.nnx.Embed(*args, **kwargs)[源代码]#
嵌入模块。
用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'embedding': VariableState( type=Param, value=Array([[-0.90411377, -0.3648777 , -1.1083648 ], [ 0.01070483, 0.27923733, 1.7487359 ], [ 0.59161806, 0.8660184 , 1.2838588 ], [-0.748139 , -0.15856352, 0.06061118], [-0.4769059 , -0.6607095 , 0.46697947]], dtype=float32) ) }) >>> # get the first three and last three embeddings >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> layer(indices_input) Array([[[-0.90411377, -0.3648777 , -1.1083648 ], [ 0.01070483, 0.27923733, 1.7487359 ], [ 0.59161806, 0.8660184 , 1.2838588 ]], [[-0.4769059 , -0.6607095 , 0.46697947], [-0.748139 , -0.15856352, 0.06061118], [ 0.59161806, 0.8660184 , 1.2838588 ]]], dtype=float32)
从整数 [0,
num_embeddings
) 到features
维向量的参数化函数。此Module
将创建一个形状为(num_embeddings, features)
的embedding
矩阵。调用此层时,输入值将用于对embedding
矩阵进行 0 索引。索引大于或等于num_embeddings
的值将导致nan
值。当num_embeddings
等于 1 时,它会将embedding
矩阵广播到具有附加features
维度的输入形状。- num_embeddings#
嵌入数量/词汇大小。
- features#
每个嵌入的特征维度数。
- dtype#
嵌入向量的数据类型(默认:与嵌入相同)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- embedding_init#
嵌入初始化器。
- rngs#
rng 键。
- __call__(inputs)[源代码]#
沿最后一个维度嵌入输入。
- 参数
inputs – 输入数据,所有维度都被视为批次维度。输入数组中的值必须是整数。
- 返回
嵌入输入数据的输出。输出形状遵循输入,并附加一个额外的
features
维度。
- attend(query)[源代码]#
使用查询数组在嵌入上进行注意力计算。
- 参数
query – 最后一个维度等于嵌入的特征深度
features
的数组。- 返回
一个数组,其最终维度为
num_embeddings
,对应于查询向量数组与每个嵌入的批次内积。通常用于 NLP 模型中嵌入和 logits 变换之间的权重共享。
方法
attend
(query)使用查询数组在嵌入上进行注意力计算。
- class flax.nnx.Linear(*args, **kwargs)[源代码]#
应用于输入的最后一个维度的线性变换。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': VariableState( type=Param, value=(4,) ), 'kernel': VariableState( type=Param, value=(3, 4) ) })
- in_features#
输入特征的数量。
- out_features#
输出特征的数量。
- use_bias#
是否在输出中添加偏置(默认值:True)。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参见
jax.lax.Precision
。
- kernel_init#
权重矩阵的初始化函数。
- bias_init#
偏置的初始化函数。
- dot_general#
点积函数。
- rngs#
rng 键。
方法
- class flax.nnx.LinearGeneral(*args, **kwargs)[源代码]#
具有灵活轴的线性变换。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` >>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4) >>> # output features (4, 5) >>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4, 5) >>> layer.bias.value.shape (4, 5) >>> # apply transformation on the the second and last axes >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 3, 4, 5) >>> layer.bias.value.shape (4, 5) >>> y = layer(jnp.ones((16, 2, 3))) >>> y.shape (16, 4, 5)
- in_features#
int 或包含输入特征数量的元组。
- out_features#
int 或包含输出特征数量的元组。
- axis#
要在其上应用变换的轴的整数或元组。例如,(-2, -1) 将变换应用于最后两个轴。
- batch_axis#
批次轴索引到轴大小的映射。
- use_bias#
是否在输出中添加偏置(默认值:True)。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- kernel_init#
权重矩阵的初始化函数。
- bias_init#
偏置的初始化函数。
- precision#
计算的数值精度,有关详细信息,请参见
jax.lax.Precision
。
- rngs#
rng 键。
方法
- class flax.nnx.Einsum(*args, **kwargs)[源代码]#
一个带有可学习内核和偏置的 einsum 转换。
用法示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (8, 2, 4) >>> layer.bias.value.shape (8, 4) >>> y = layer(jnp.ones((16, 11, 2))) >>> y.shape (16, 11, 8, 4)
- einsum_str#
一个表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧操作数是传入的输入,右侧操作数是可学习的内核。构造函数参数和调用参数中的
einsum_str
必须有一个不为 None,而另一个必须为 None。
- kernel_shape#
内核的形状。
- bias_shape#
偏置的形状。如果为 None,则不会使用偏置。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参见
jax.lax.Precision
。
- kernel_init#
权重矩阵的初始化函数。
- bias_init#
偏置的初始化函数。
- rngs#
rng 键。
- __call__(inputs, einsum_str=None)[源代码]#
沿最后一个维度将线性变换应用于输入。
- 参数
inputs – 要变换的 nd 数组。
einsum_str – 一个表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧操作数是传入的输入,右侧操作数是可学习的内核。构造函数参数和调用参数中的
einsum_str
必须有一个不为 None,而另一个必须为 None。
- 返回
变换后的输入。
方法