LoRA#
NNX LoRA 类。
- class flax.nnx.LoRA(*args, **kwargs)[源代码]#
一个独立的 LoRA 层。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0)) >>> layer.lora_a.value.shape (3, 2) >>> layer.lora_b.value.shape (2, 4) >>> # Wrap around existing layer >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1)) >>> assert wrapper.base_module == linear >>> wrapper.lora_a.value.shape (3, 2) >>> layer.lora_b.value.shape (2, 4) >>> y = layer(jnp.ones((16, 3))) >>> y.shape (16, 4)
- in_features#
输入特征的数量。
- lora_rank#
LoRA 维度的秩。
- out_features#
输出特征的数量。
- base_module#
一个基础模块,用于调用和替换(如果可能)。
- dtype#
计算的数据类型(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认:float32)。
- precision#
计算的数值精度,详情请参见jax.lax.Precision。
- kernel_init#
权重矩阵的初始化函数。
- lora_param_type#
LoRA 参数的类型。
方法
- class flax.nnx.LoRALinear(*args, **kwargs)[源代码]#
一个 nnx.Linear 层,其中输出将进行 LoRA 化。
模型状态结构将与 Linear 的状态结构兼容。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0)) >>> linear.kernel.value.shape (3, 4) >>> lora_linear.kernel.value.shape (3, 4) >>> lora_linear.lora.lora_a.value.shape (3, 2) >>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value) Array(True, dtype=bool) >>> y = lora_linear(jnp.ones((16, 3))) >>> y.shape (16, 4)
- in_features#
输入特征的数量。
- out_features#
输出特征的数量。
- lora_rank#
LoRA 维度的秩。
- base_module#
一个基础模块,用于调用和替换(如果可能)。
- dtype#
计算的数据类型(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认:float32)。
- precision#
计算的数值精度,详情请参见jax.lax.Precision。
- kernel_init#
权重矩阵的初始化函数。
- lora_param_type#
LoRA 参数的类型。
方法