注意力#

class flax.nnx.MultiHeadAttention(*args, **kwargs)[源代码]#

多头注意力。

使用示例

>>> from flax import nnx
>>> import jax

>>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
...                                decode=False, rngs=nnx.Rngs(0))
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = (
...   jax.random.uniform(key1, shape),
...   jax.random.uniform(key2, shape),
...   jax.random.uniform(key3, shape),
... )

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # equivalent output when inferring v
>>> assert (layer(q, k) == layer(q, k, k)).all()
>>> # equivalent output when inferring k and v
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
num_heads#

注意力头的数量。特征(即 inputs_q.shape[-1])应该可以被头的数量整除。

in_features#

整数或包含输入特征数量的元组。

qkv_features#

键、查询和值的维度。

out_features#

最后投影的维度

dtype#

计算的数据类型(默认:从输入和参数推断)

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)

broadcast_dropout#

布尔值:使用沿批次维度的广播 dropout。

dropout_rate#

dropout 率

deterministic#

如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定性的。

precision#

计算的数值精度,有关详细信息,请参见jax.lax.Precision

kernel_init#

用于密集层内核的初始化器。

out_kernel_init#

用于输出密集层内核的可选初始化器,如果为 None,则使用 kernel_init。

bias_init#

用于密集层偏置的初始化器。

out_bias_init#

用于输出密集层偏置的可选初始化器,如果为 None,则使用 bias_init。

use_bias#

布尔值:点式 QKVO 密集变换是否使用偏置。

attention_fn#

dot_product_attention 或兼容函数。接受查询、键、值,并返回形状为 [bs, dim1, dim2, …, dimN,, num_heads, value_channels]` 的输出

decode#

是否准备并使用自回归缓存。

normalize_qk#

是否应该应用 QK 归一化 (arxiv.org/abs/2302.05442)。

rngs#

rng 键。

__call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[源代码]#

对输入数据应用多头点积注意力。

将输入投影到多头查询、键和值向量中,应用点积注意力并将结果投影到输出向量。

如果 inputs_k 和 inputs_v 均为 None,则它们都将复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes…, length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes…, length, features] 的键。如果为 None,则 inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes…, length, features] 的值。如果为 None,则 inputs_v 将复制 inputs_k 的值。

  • mask – 形状为 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力掩码。如果其对应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定性的。传递到调用方法的 deterministic 标志将优先于传递到构造函数的 deterministic 标志。

  • rngs – rng 键。传递到调用方法的 rng 键将优先于传递到构造函数的 rng 键。

  • sow_weights – 如果为 True,则注意力权重将播种到“intermediates”集合中。

  • decode – 是否准备并使用自回归缓存。传递到调用方法的 decode 标志将优先于传递到构造函数的 decode 标志。

返回

形状为 [batch_sizes…, length, features] 的输出。

init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[源代码]#

初始化用于快速自回归解码的缓存。当 decode=True 时,必须先调用此方法,然后再执行前向推理。在解码模式下,一次只能传递一个令牌。

使用示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> batch_size = 5
>>> embed_dim = 3
>>> x = jnp.ones((batch_size, 1, embed_dim)) # single token
...
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(42),
... )
...
>>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

方法

init_cache(input_shape[, dtype])

初始化用于快速自回归解码的缓存。

flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[源代码]#

组合注意力掩码。

参数
  • *masks – 要组合的注意力掩码参数集合,其中一些可以是 None。

  • dtype – 返回的掩码的数据类型。

返回

组合后的掩码,通过逻辑与运算缩减,如果没有给定掩码则返回 None。

flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[源代码]#

计算给定查询(query)、键(key)和值(value)的点积注意力。

这是基于 https://arxiv.org/abs/1706.03762 应用注意力的核心函数。它计算给定查询和键的注意力权重,并使用注意力权重组合值。

注意

query, key, value 不需要任何批次维度。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • value – 在注意力中使用的值,形状为 [batch..., kv_length, num_heads, v_depth_per_head]

  • bias – 注意力权重的偏置。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码、填充掩码、邻近偏置等。

  • mask – 注意力权重的掩码。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码。如果注意力权重对应的掩码值为 False,则会被屏蔽。

  • broadcast_dropout – bool: 沿批次维度使用广播 dropout。

  • dropout_rng – JAX PRNGKey: 用于 dropout

  • dropout_rate – dropout 率

  • deterministic – bool,确定性或非确定性(用于应用 dropout)

  • dtype – 计算的数据类型 (默认: 从输入推断)

  • precision – 计算的数值精度,请参阅 jax.lax.Precision 了解详细信息。

  • module – 将注意力权重注入 nnx.Intermediate 集合的模块。如果 module 为 None,则不会注入注意力权重。

返回

输出的形状为 [batch…, q_length, num_heads, v_depth_per_head]

flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#

用于注意力权重的掩码创建助手。

对于 1d 输入 (例如,[batch…, len_q], [batch…, len_kv]),注意力权重将为 [batch…, heads, len_q, len_kv],此函数将生成 [batch…, 1, len_q, len_kv]

参数
  • query_input – 查询长度大小的批处理、扁平输入

  • key_input – 键长度大小的批处理、扁平输入

  • pairwise_fn – 广播逐元素比较函数

  • extra_batch_dims – 添加单例轴的额外批次维度数,默认为无

  • dtype – 掩码返回数据类型

返回

用于 1d 注意力的形状为 [batch…, 1, len_q, len_kv] 的掩码。

flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#

为自注意力创建因果掩码。

对于 1d 输入 (例如,[batch…, len]),自注意力权重将为 [batch…, heads, len, len],此函数将生成形状为 [batch…, 1, len, len] 的因果掩码。

参数
  • x – 形状为 [batch…, len] 的输入数组

  • extra_batch_dims – 添加单例轴的批次维度数,默认为无

  • dtype – 掩码返回数据类型

返回

用于 1d 注意力的形状为 [batch…, 1, len, len] 的因果掩码。