注意力#
- class flax.nnx.MultiHeadAttention(*args, **kwargs)[源代码]#
多头注意力。
使用示例
>>> from flax import nnx >>> import jax >>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16, ... decode=False, rngs=nnx.Rngs(0)) >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> shape = (4, 3, 2, 5) >>> q, k, v = ( ... jax.random.uniform(key1, shape), ... jax.random.uniform(key2, shape), ... jax.random.uniform(key3, shape), ... ) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer(q, k, v) >>> # equivalent output when inferring v >>> assert (layer(q, k) == layer(q, k, k)).all() >>> # equivalent output when inferring k and v >>> assert (layer(q) == layer(q, q)).all() >>> assert (layer(q) == layer(q, q, q)).all()
- num_heads#
注意力头的数量。特征(即 inputs_q.shape[-1])应该可以被头的数量整除。
- in_features#
整数或包含输入特征数量的元组。
- qkv_features#
键、查询和值的维度。
- out_features#
最后投影的维度
- dtype#
计算的数据类型(默认:从输入和参数推断)
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)
- broadcast_dropout#
布尔值:使用沿批次维度的广播 dropout。
- dropout_rate#
dropout 率
- deterministic#
如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定性的。
- precision#
计算的数值精度,有关详细信息,请参见jax.lax.Precision。
- kernel_init#
用于密集层内核的初始化器。
- out_kernel_init#
用于输出密集层内核的可选初始化器,如果为 None,则使用 kernel_init。
- bias_init#
用于密集层偏置的初始化器。
- out_bias_init#
用于输出密集层偏置的可选初始化器,如果为 None,则使用 bias_init。
- use_bias#
布尔值:点式 QKVO 密集变换是否使用偏置。
- attention_fn#
dot_product_attention 或兼容函数。接受查询、键、值,并返回形状为 [bs, dim1, dim2, …, dimN,, num_heads, value_channels]` 的输出
- decode#
是否准备并使用自回归缓存。
- normalize_qk#
是否应该应用 QK 归一化 (arxiv.org/abs/2302.05442)。
- rngs#
rng 键。
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[源代码]#
对输入数据应用多头点积注意力。
将输入投影到多头查询、键和值向量中,应用点积注意力并将结果投影到输出向量。
如果 inputs_k 和 inputs_v 均为 None,则它们都将复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。
- 参数
inputs_q – 形状为 [batch_sizes…, length, features] 的输入查询。
inputs_k – 形状为 [batch_sizes…, length, features] 的键。如果为 None,则 inputs_k 将复制 inputs_q 的值。
inputs_v – 形状为 [batch_sizes…, length, features] 的值。如果为 None,则 inputs_v 将复制 inputs_k 的值。
mask – 形状为 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力掩码。如果其对应的掩码值为 False,则会屏蔽掉注意力权重。
deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定性的。传递到调用方法的
deterministic
标志将优先于传递到构造函数的deterministic
标志。rngs – rng 键。传递到调用方法的 rng 键将优先于传递到构造函数的 rng 键。
sow_weights – 如果为
True
,则注意力权重将播种到“intermediates”集合中。decode – 是否准备并使用自回归缓存。传递到调用方法的
decode
标志将优先于传递到构造函数的decode
标志。
- 返回
形状为 [batch_sizes…, length, features] 的输出。
- init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[源代码]#
初始化用于快速自回归解码的缓存。当
decode=True
时,必须先调用此方法,然后再执行前向推理。在解码模式下,一次只能传递一个令牌。使用示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> batch_size = 5 >>> embed_dim = 3 >>> x = jnp.ones((batch_size, 1, embed_dim)) # single token ... >>> model_nnx = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, ... rngs=nnx.Rngs(42), ... ) ... >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized ... >>> model_nnx.init_cache(x.shape) >>> out_nnx = model_nnx(x)
方法
init_cache
(input_shape[, dtype])初始化用于快速自回归解码的缓存。
- flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[源代码]#
组合注意力掩码。
- 参数
*masks – 要组合的注意力掩码参数集合,其中一些可以是 None。
dtype – 返回的掩码的数据类型。
- 返回
组合后的掩码,通过逻辑与运算缩减,如果没有给定掩码则返回 None。
- flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[源代码]#
计算给定查询(query)、键(key)和值(value)的点积注意力。
这是基于 https://arxiv.org/abs/1706.03762 应用注意力的核心函数。它计算给定查询和键的注意力权重,并使用注意力权重组合值。
注意
query
,key
,value
不需要任何批次维度。- 参数
query – 用于计算注意力的查询,形状为
[batch..., q_length, num_heads, qk_depth_per_head]
。key – 用于计算注意力的键,形状为
[batch..., kv_length, num_heads, qk_depth_per_head]
。value – 在注意力中使用的值,形状为
[batch..., kv_length, num_heads, v_depth_per_head]
。bias – 注意力权重的偏置。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码、填充掩码、邻近偏置等。
mask – 注意力权重的掩码。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码。如果注意力权重对应的掩码值为 False,则会被屏蔽。
broadcast_dropout – bool: 沿批次维度使用广播 dropout。
dropout_rng – JAX PRNGKey: 用于 dropout
dropout_rate – dropout 率
deterministic – bool,确定性或非确定性(用于应用 dropout)
dtype – 计算的数据类型 (默认: 从输入推断)
precision – 计算的数值精度,请参阅 jax.lax.Precision 了解详细信息。
module – 将注意力权重注入
nnx.Intermediate
集合的模块。如果module
为 None,则不会注入注意力权重。
- 返回
输出的形状为 [batch…, q_length, num_heads, v_depth_per_head]。
- flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#
用于注意力权重的掩码创建助手。
对于 1d 输入 (例如,[batch…, len_q], [batch…, len_kv]),注意力权重将为 [batch…, heads, len_q, len_kv],此函数将生成 [batch…, 1, len_q, len_kv]。
- 参数
query_input – 查询长度大小的批处理、扁平输入
key_input – 键长度大小的批处理、扁平输入
pairwise_fn – 广播逐元素比较函数
extra_batch_dims – 添加单例轴的额外批次维度数,默认为无
dtype – 掩码返回数据类型
- 返回
用于 1d 注意力的形状为 [batch…, 1, len_q, len_kv] 的掩码。
- flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#
为自注意力创建因果掩码。
对于 1d 输入 (例如,[batch…, len]),自注意力权重将为 [batch…, heads, len, len],此函数将生成形状为 [batch…, 1, len, len] 的因果掩码。
- 参数
x – 形状为 [batch…, len] 的输入数组
extra_batch_dims – 添加单例轴的批次维度数,默认为无
dtype – 掩码返回数据类型
- 返回
用于 1d 注意力的形状为 [batch…, 1, len, len] 的因果掩码。