LoRA#

NNX LoRA 类。

class flax.nnx.LoRA(*args, **kwargs)[源代码]#

一个独立的 LoRA 层。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
>>> layer.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> # Wrap around existing layer
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
>>> assert wrapper.base_module == linear
>>> wrapper.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> y = layer(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
in_features#

输入特征的数量。

lora_rank#

LoRA 维度的秩。

out_features#

输出特征的数量。

base_module#

一个基础模块,用于调用和替换(如果可能)。

dtype#

计算的数据类型(默认:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认:float32)。

precision#

计算的数值精度,详情请参见jax.lax.Precision

kernel_init#

权重矩阵的初始化函数。

lora_param_type#

LoRA 参数的类型。

__call__(x)[源代码]#

将自身作为函数调用。

方法

class flax.nnx.LoRALinear(*args, **kwargs)[源代码]#

一个 nnx.Linear 层,其中输出将进行 LoRA 化。

模型状态结构将与 Linear 的状态结构兼容。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
>>> linear.kernel.value.shape
(3, 4)
>>> lora_linear.kernel.value.shape
(3, 4)
>>> lora_linear.lora.lora_a.value.shape
(3, 2)
>>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value)
Array(True, dtype=bool)
>>> y = lora_linear(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
in_features#

输入特征的数量。

out_features#

输出特征的数量。

lora_rank#

LoRA 维度的秩。

base_module#

一个基础模块,用于调用和替换(如果可能)。

dtype#

计算的数据类型(默认:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认:float32)。

precision#

计算的数值精度,详情请参见jax.lax.Precision

kernel_init#

权重矩阵的初始化函数。

lora_param_type#

LoRA 参数的类型。

__call__(x)[源代码]#

沿最后一个维度对输入应用线性变换。

参数

inputs – 要转换的 nd 数组。

返回

变换后的输入。

方法