归一化#
- class flax.nnx.BatchNorm(*args, **kwargs)[源代码]#
BatchNorm 模块。
为了计算输入的批归一化并更新批统计信息,请调用
train()
方法(或者在构造函数或调用时传入use_running_average=False
)。要使用存储的批统计信息的运行平均值,请调用
eval()
方法(或者在构造函数或调用时传入use_running_average=True
)。用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5, ... dtype=jnp.float32, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': VariableState( type=Param, value=(6,) ), 'mean': VariableState( type=BatchStat, value=(6,) ), 'scale': VariableState( type=Param, value=(6,) ), 'var': VariableState( type=BatchStat, value=(6,) ) }) >>> # calculate batch norm on input and update batch statistics >>> layer.train() >>> y = layer(x) >>> batch_stats1 = nnx.state(layer, nnx.BatchStat) >>> y = layer(x) >>> batch_stats2 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all() >>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all() >>> # use stored batch statistics' running average >>> layer.eval() >>> y = layer(x) >>> batch_stats3 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all() >>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
- num_features#
输入特征的数量。
- use_running_average#
如果为 True,则将使用存储的批统计信息,而不是计算输入的批统计信息。
- axis#
输入的特征轴或非批处理轴。
- momentum#
批统计信息的指数移动平均的衰减率。
- epsilon#
添加到方差中的一个小浮点数,以避免除以零。
- dtype#
结果的 dtype(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- use_bias#
如果为 True,则添加偏差(beta)。
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性层(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。
- bias_init#
偏差的初始化器,默认情况下为零。
- scale_init#
比例的初始化器,默认情况下为一。
- axis_name#
用于组合来自多个设备的批统计信息的轴名称。有关轴名称的描述,请参见
jax.pmap
(默认:None)。
- axis_index_groups#
该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见jax.lax.psum
。
- use_fast_variance#
如果为 true,则使用更快但数值稳定性较差的方差计算方法。
- rngs#
rng 密钥。
- __call__(x, use_running_average=None, *, mask=None)[源代码]#
使用批统计信息对输入进行归一化。
- 参数
x – 要归一化的输入。
use_running_average – 如果为 true,则将使用存储的批统计信息,而不是计算输入的批统计信息。传递到调用方法中的
use_running_average
标志将优先于传递到构造函数中的use_running_average
标志。
- 返回
归一化的输入(与输入相同的形状)。
方法
- class flax.nnx.LayerNorm(*args, **kwargs)[源代码]#
层归一化(https://arxiv.org/abs/1607.06450)。
LayerNorm 独立地对批处理中每个给定示例的层的激活进行归一化,而不是像批归一化那样跨批处理。也就是说,应用一个转换,使每个示例内的平均激活保持接近 0,并且激活标准差接近 1。
用法示例
>>> from flax import nnx >>> import jax >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'bias': VariableState( type=Param, value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), 'scale': VariableState( type=Param, value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x)
- num_features#
输入特征的数量。
- epsilon#
添加到方差中的一个小浮点数,以避免除以零。
- dtype#
结果的 dtype(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- use_bias#
如果为 True,则添加偏差(beta)。
- use_scale#
如果为 True,则乘以比例(gamma)。当下一层是线性层(例如 nnx.relu)时,可以禁用此选项,因为缩放将由下一层完成。
- bias_init#
偏差的初始化器,默认情况下为零。
- scale_init#
比例的初始化器,默认情况下为一。
- reduction_axes#
用于计算归一化统计信息的轴。
- feature_axes#
用于学习偏差和缩放的特征轴。
- axis_name#
用于组合来自多个设备的批统计信息的轴名称。有关轴名称的描述,请参见
jax.pmap
(默认:None)。只有当模型细分为跨设备的子模型时,才需要此选项,即被归一化的数组在 pmap 中的设备之间分片。
- axis_index_groups#
该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见jax.lax.psum
。
- use_fast_variance#
如果为 true,则使用更快但数值稳定性较差的方差计算方法。
- rngs#
rng 密钥。
方法
- class flax.nnx.RMSNorm(*args, **kwargs)[source]#
RMS层归一化 (https://arxiv.org/abs/1910.07467)。
RMSNorm独立地对批次中每个给定示例的层激活进行归一化,而不是像批量归一化那样跨批次进行归一化。与LayerNorm将均值重新居中为0并按激活的标准差进行归一化不同,RMSNorm根本不重新居中,而是按激活的均方根进行归一化。
用法示例
>>> from flax import nnx >>> import jax >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'scale': VariableState( type=Param, value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x)
- num_features#
输入特征的数量。
- epsilon#
添加到方差中的一个小浮点数,以避免除以零。
- dtype#
结果的 dtype(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- use_scale#
如果为 True,则乘以 scale (gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。
- scale_init#
比例的初始化器,默认情况下为一。
- reduction_axes#
用于计算归一化统计信息的轴。
- feature_axes#
用于学习偏差和缩放的特征轴。
- axis_name#
用于组合来自多个设备的批统计信息的轴名称。有关轴名称的描述,请参见
jax.pmap
(默认:None)。只有当模型细分为跨设备的子模型时,才需要此选项,即被归一化的数组在 pmap 中的设备之间分片。
- axis_index_groups#
该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见jax.lax.psum
。
- use_fast_variance#
如果为 true,则使用更快但数值稳定性较差的方差计算方法。
- rngs#
rng 密钥。
方法
- class flax.nnx.GroupNorm(*args, **kwargs)[source]#
组归一化 (arxiv.org/abs/1803.08494)。
此操作类似于批量归一化,但统计信息在大小相等的通道组之间共享,而不是跨批次维度共享。因此,组归一化不依赖于批次组成,并且不需要维护内部状态来存储统计信息。用户应指定通道组的总数或每个组的通道数。
注意
LayerNorm 是 GroupNorm 的一种特殊情况,其中
num_groups=1
。用法示例
>>> from flax import nnx >>> import jax >>> import numpy as np ... >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'bias': VariableState( type=Param, value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), 'scale': VariableState( type=Param, value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x) ... >>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x) >>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x) >>> np.testing.assert_allclose(y, y2)
- num_features#
输入特征/通道的数量。
- num_groups#
通道组的总数。原始组归一化论文建议的默认值为 32。
- group_size#
一个组中的通道数。
- epsilon#
添加到方差中的一个小浮点数,以避免除以零。
- dtype#
结果的 dtype(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- use_bias#
如果为 True,则添加偏差(beta)。
- use_scale#
如果为 True,则乘以 scale (gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。
- bias_init#
偏差的初始化器,默认情况下为零。
- scale_init#
比例的初始化器,默认情况下为一。
- reduction_axes#
用于计算归一化统计信息的轴的列表。此列表必须包含最后一个维度,该维度被假定为特征轴。此外,如果调用时使用的输入与用于初始化的数据相比具有额外的领先轴(例如,由于批处理),则需要显式定义缩减轴。
- axis_name#
用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参见
jax.pmap
(默认值:None)。仅当模型在设备之间细分时才需要此操作,即被归一化的数组在 pmap 或 shard map 内的设备之间分片。对于 SPMD jit,您无需手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集合。
- axis_index_groups#
该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,
[[0, 1], [2, 3]]
将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见jax.lax.psum
。
- use_fast_variance#
如果为 true,则使用更快但数值稳定性较差的方差计算方法。
- rngs#
rng 密钥。
- __call__(x, *, mask=None)[source]#
将组归一化应用于输入 (arxiv.org/abs/1803.08494)。
- 参数
x – 形状为
...self.num_features
的输入,其中self.num_features
是通道维度,...
表示可用于累积统计信息的任意数量的额外维度。如果未指定缩减轴,则所有额外的维度...
将用于累积统计信息,除了被假定为表示批次的领先维度。mask – 可广播到
inputs
张量的二进制数组,指示应计算均值和方差的位置。
- 返回
归一化的输入(与输入相同的形状)。
方法