线性

线性#

NNX 线性层类。

class flax.nnx.Conv(*args, **kwargs)[源代码]#

卷积模块,封装 lax.conv_general_dilated

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

int 或包含输入特征数量的元组。

out_features#

int 或包含输出特征数量的元组。

kernel_size#

卷积核的形状。对于 1D 卷积,内核大小可以作为整数传递,这将解释为单个整数的元组。对于所有其他情况,它必须是整数序列。

strides#

一个整数或一个 n 个整数的序列,表示窗口之间的步长(默认值:1)。

padding#

字符串 'SAME'、字符串 'VALID'、字符串 'CIRCULAR'(周期性边界条件)或 n(low, high) 整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度中应用相同的填充,并且在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL' 填充用于 1D 卷积将左填充卷积轴,从而产生相同大小的输出。

input_dilation#

一个整数或一个 n 个整数的序列,给出要应用于 inputs 的每个空间维度的扩张因子(默认值:1)。具有输入扩张 d 的卷积等效于具有步长 d 的转置卷积。

kernel_dilation#

一个整数或一个 n 个整数的序列,给出要应用于卷积核的每个空间维度的扩张因子(默认值:1)。具有核扩张的卷积也称为“空洞卷积”。

feature_group_count#

整数,默认值为 1。如果指定,则将输入特征分为若干组。

use_bias#

是否在输出中添加偏置(默认值:True)。

mask#

掩码卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参见 jax.lax.Precision

kernel_init#

卷积核的初始化器。

bias_init#

偏置的初始化器。

rngs#

rng 键。

__call__(inputs)[源代码]#

将(可能未共享的)卷积应用于输入。

参数

inputs – 输入数据,其维度为 (*batch_dims, spatial_dims..., features)。这是通道最后约定,即 2D 卷积为 NHWC,3D 卷积为 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度都会被展平为单个维度进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 该层可能会比这种默认展平方法产生更好的性能。如果输入缺少批次维度,则会为卷积添加该维度,并在返回时删除该维度,从而允许编写单个示例代码。

返回

卷积后的数据。

方法

class flax.nnx.ConvTranspose(*args, **kwargs)[源代码]#

卷积模块,封装 lax.conv_transpose

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.value.shape
(6, 6, 4, 3)
>>> layer.bias.value.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

int 或包含输入特征数量的元组。

out_features#

int 或包含输出特征数量的元组。

kernel_size#

卷积核的形状。对于 1D 卷积,内核大小可以作为整数传递,这将解释为单个整数的元组。对于所有其他情况,它必须是整数序列。

strides#

一个整数或一个 n 个整数的序列,表示窗口之间的步长(默认值:1)。

padding#

字符串 'SAME'、字符串 'VALID'、字符串 'CIRCULAR'(周期性边界条件)或 n(low, high) 整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度中应用相同的填充,并且在序列中传递单个整数会导致在两侧使用相同的填充。'CAUSAL' 填充用于 1D 卷积将左填充卷积轴,从而产生相同大小的输出。

kernel_dilation#

一个整数或一个 n 个整数的序列,给出要应用于卷积核的每个空间维度的扩张因子(默认值:1)。具有核扩张的卷积也称为“空洞卷积”。

use_bias#

是否在输出中添加偏置(默认值:True)。

mask#

掩码卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参见 jax.lax.Precision

kernel_init#

卷积核的初始化器。

bias_init#

偏置的初始化器。

transpose_kernel#

如果 True,则翻转空间轴并交换内核的输入/输出通道轴。

rngs#

rng 键。

__call__(inputs)[源代码]#

将转置卷积应用于输入。

行为与 jax.lax.conv_transpose 类似。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。这是通道最后(channels-last)的约定,例如,对于二维卷积是 NHWC,对于三维卷积是 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入有多个批次维度,所有批次维度将展平为一个维度进行卷积,并在返回之前恢复。在某些情况下,直接对层进行 vmap 操作可能比此默认展平方法产生更好的性能。如果输入缺少批次维度,则将添加一个批次维度用于卷积,并在返回时删除,以便编写单样本代码。

返回

卷积后的数据。

方法

class flax.nnx.Embed(*args, **kwargs)[源代码]#

嵌入模块。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': VariableState(
    type=Param,
    value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
           [ 0.01070483,  0.27923733,  1.7487359 ],
           [ 0.59161806,  0.8660184 ,  1.2838588 ],
           [-0.748139  , -0.15856352,  0.06061118],
           [-0.4769059 , -0.6607095 ,  0.46697947]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
        [ 0.01070483,  0.27923733,  1.7487359 ],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]],

       [[-0.4769059 , -0.6607095 ,  0.46697947],
        [-0.748139  , -0.15856352,  0.06061118],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]]], dtype=float32)

从整数 [0, num_embeddings) 到 features 维向量的参数化函数。此 Module 将创建一个形状为 (num_embeddings, features)embedding 矩阵。调用此层时,输入值将用于对 embedding 矩阵进行 0 索引。索引大于或等于 num_embeddings 的值将导致 nan 值。当 num_embeddings 等于 1 时,它会将 embedding 矩阵广播到具有附加 features 维度的输入形状。

num_embeddings#

嵌入数量/词汇大小。

features#

每个嵌入的特征维度数。

dtype#

嵌入向量的数据类型(默认:与嵌入相同)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

embedding_init#

嵌入初始化器。

rngs#

rng 键。

__call__(inputs)[源代码]#

沿最后一个维度嵌入输入。

参数

inputs – 输入数据,所有维度都被视为批次维度。输入数组中的值必须是整数。

返回

嵌入输入数据的输出。输出形状遵循输入,并附加一个额外的 features 维度。

attend(query)[源代码]#

使用查询数组在嵌入上进行注意力计算。

参数

query – 最后一个维度等于嵌入的特征深度 features 的数组。

返回

一个数组,其最终维度为 num_embeddings,对应于查询向量数组与每个嵌入的批次内积。通常用于 NLP 模型中嵌入和 logits 变换之间的权重共享。

方法

attend(query)

使用查询数组在嵌入上进行注意力计算。

class flax.nnx.Linear(*args, **kwargs)[源代码]#

应用于输入的最后一个维度的线性变换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(4,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(3, 4)
  )
})
in_features#

输入特征的数量。

out_features#

输出特征的数量。

use_bias#

是否在输出中添加偏置(默认值:True)。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参见 jax.lax.Precision

kernel_init#

权重矩阵的初始化函数。

bias_init#

偏置的初始化函数。

dot_general#

点积函数。

rngs#

rng 键。

__call__(inputs)[源代码]#

沿最后一个维度将线性变换应用于输入。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.nnx.LinearGeneral(*args, **kwargs)[源代码]#

具有灵活轴的线性变换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> # apply transformation on the the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
in_features#

int 或包含输入特征数量的元组。

out_features#

int 或包含输出特征数量的元组。

axis#

要在其上应用变换的轴的整数或元组。例如,(-2, -1) 将变换应用于最后两个轴。

batch_axis#

批次轴索引到轴大小的映射。

use_bias#

是否在输出中添加偏置(默认值:True)。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

kernel_init#

权重矩阵的初始化函数。

bias_init#

偏置的初始化函数。

precision#

计算的数值精度,有关详细信息,请参见 jax.lax.Precision

rngs#

rng 键。

__call__(inputs)[源代码]#

将线性变换应用于输入的多个维度。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.nnx.Einsum(*args, **kwargs)[源代码]#

一个带有可学习内核和偏置的 einsum 转换。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(8, 2, 4)
>>> layer.bias.value.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
einsum_str#

一个表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧操作数是传入的输入,右侧操作数是可学习的内核。构造函数参数和调用参数中的 einsum_str 必须有一个不为 None,而另一个必须为 None。

kernel_shape#

内核的形状。

bias_shape#

偏置的形状。如果为 None,则不会使用偏置。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参见 jax.lax.Precision

kernel_init#

权重矩阵的初始化函数。

bias_init#

偏置的初始化函数。

rngs#

rng 键。

__call__(inputs, einsum_str=None)[源代码]#

沿最后一个维度将线性变换应用于输入。

参数
  • inputs – 要变换的 nd 数组。

  • einsum_str – 一个表示 einsum 方程的字符串。该方程必须恰好有两个操作数,左侧操作数是传入的输入,右侧操作数是可学习的内核。构造函数参数和调用参数中的 einsum_str 必须有一个不为 None,而另一个必须为 None。

返回

变换后的输入。

方法