归一化#

class flax.nnx.BatchNorm(*args, **kwargs)[源代码]#

BatchNorm 模块。

为了计算输入的批归一化并更新批统计信息,请调用 train() 方法(或者在构造函数或调用时传入 use_running_average=False)。

要使用存储的批统计信息的运行平均值,请调用 eval() 方法(或者在构造函数或调用时传入 use_running_average=True)。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(6,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(6,)
  ),
  'scale': VariableState(
    type=Param,
    value=(6,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(6,)
  )
})

>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all()
>>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all()

>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
num_features#

输入特征的数量。

use_running_average#

如果为 True,则将使用存储的批统计信息,而不是计算输入的批统计信息。

axis#

输入的特征轴或非批处理轴。

momentum#

批统计信息的指数移动平均的衰减率。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

use_bias#

如果为 True,则添加偏差(beta)。

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性层(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。

bias_init#

偏差的初始化器,默认情况下为零。

scale_init#

比例的初始化器,默认情况下为一。

axis_name#

用于组合来自多个设备的批统计信息的轴名称。有关轴名称的描述,请参见 jax.pmap(默认:None)。

axis_index_groups#

该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算方法。

rngs#

rng 密钥。

__call__(x, use_running_average=None, *, mask=None)[源代码]#

使用批统计信息对输入进行归一化。

参数
  • x – 要归一化的输入。

  • use_running_average – 如果为 true,则将使用存储的批统计信息,而不是计算输入的批统计信息。传递到调用方法中的 use_running_average 标志将优先于传递到构造函数中的 use_running_average 标志。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.nnx.LayerNorm(*args, **kwargs)[源代码]#

层归一化(https://arxiv.org/abs/1607.06450)。

LayerNorm 独立地对批处理中每个给定示例的层的激活进行归一化,而不是像批归一化那样跨批处理。也就是说,应用一个转换,使每个示例内的平均激活保持接近 0,并且激活标准差接近 1。

用法示例

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
num_features#

输入特征的数量。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

use_bias#

如果为 True,则添加偏差(beta)。

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性层(例如 nnx.relu)时,可以禁用此选项,因为缩放将由下一层完成。

bias_init#

偏差的初始化器,默认情况下为零。

scale_init#

比例的初始化器,默认情况下为一。

reduction_axes#

用于计算归一化统计信息的轴。

feature_axes#

用于学习偏差和缩放的特征轴。

axis_name#

用于组合来自多个设备的批统计信息的轴名称。有关轴名称的描述,请参见 jax.pmap(默认:None)。只有当模型细分为跨设备的子模型时,才需要此选项,即被归一化的数组在 pmap 中的设备之间分片。

axis_index_groups#

该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算方法。

rngs#

rng 密钥。

__call__(x, *, mask=None)[source]#

对输入应用层归一化。

参数

x – 输入

返回

归一化的输入(与输入相同的形状)。

方法

class flax.nnx.RMSNorm(*args, **kwargs)[source]#

RMS层归一化 (https://arxiv.org/abs/1910.07467)。

RMSNorm独立地对批次中每个给定示例的层激活进行归一化,而不是像批量归一化那样跨批次进行归一化。与LayerNorm将均值重新居中为0并按激活的标准差进行归一化不同,RMSNorm根本不重新居中,而是按激活的均方根进行归一化。

用法示例

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
num_features#

输入特征的数量。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

use_scale#

如果为 True,则乘以 scale (gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。

scale_init#

比例的初始化器,默认情况下为一。

reduction_axes#

用于计算归一化统计信息的轴。

feature_axes#

用于学习偏差和缩放的特征轴。

axis_name#

用于组合来自多个设备的批统计信息的轴名称。有关轴名称的描述,请参见 jax.pmap(默认:None)。只有当模型细分为跨设备的子模型时,才需要此选项,即被归一化的数组在 pmap 中的设备之间分片。

axis_index_groups#

该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算方法。

rngs#

rng 密钥。

__call__(x, mask=None)[source]#

对输入应用层归一化。

参数

x – 输入

返回

归一化的输入(与输入相同的形状)。

方法

class flax.nnx.GroupNorm(*args, **kwargs)[source]#

组归一化 (arxiv.org/abs/1803.08494)。

此操作类似于批量归一化,但统计信息在大小相等的通道组之间共享,而不是跨批次维度共享。因此,组归一化不依赖于批次组成,并且不需要维护内部状态来存储统计信息。用户应指定通道组的总数或每个组的通道数。

注意

LayerNorm 是 GroupNorm 的一种特殊情况,其中 num_groups=1

用法示例

>>> from flax import nnx
>>> import jax
>>> import numpy as np
...
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
...
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
num_features#

输入特征/通道的数量。

num_groups#

通道组的总数。原始组归一化论文建议的默认值为 32。

group_size#

一个组中的通道数。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认:float32)。

use_bias#

如果为 True,则添加偏差(beta)。

use_scale#

如果为 True,则乘以 scale (gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。

bias_init#

偏差的初始化器,默认情况下为零。

scale_init#

比例的初始化器,默认情况下为一。

reduction_axes#

用于计算归一化统计信息的轴的列表。此列表必须包含最后一个维度,该维度被假定为特征轴。此外,如果调用时使用的输入与用于初始化的数据相比具有额外的领先轴(例如,由于批处理),则需要显式定义缩减轴。

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参见 jax.pmap(默认值:None)。仅当模型在设备之间细分时才需要此操作,即被归一化的数组在 pmap 或 shard map 内的设备之间分片。对于 SPMD jit,您无需手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集合。

axis_index_groups#

该命名轴内的一组轴索引,表示要在其上进行缩减的设备子集(默认:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算方法。

rngs#

rng 密钥。

__call__(x, *, mask=None)[source]#

将组归一化应用于输入 (arxiv.org/abs/1803.08494)。

参数
  • x – 形状为 ...self.num_features 的输入,其中 self.num_features 是通道维度,... 表示可用于累积统计信息的任意数量的额外维度。如果未指定缩减轴,则所有额外的维度 ... 将用于累积统计信息,除了被假定为表示批次的领先维度。

  • mask – 可广播到 inputs 张量的二进制数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法