初始化器#

Flax 的初始化器。

flax.linen.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)#

构建一个初始化器,该初始化器返回充满常量 value 的数组。

参数
  • value – 用于填充初始化器的常量值。

  • dtype – 可选;初始化器的默认数据类型。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
flax.linen.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#

构建 delta 正交核的初始化器。

参数
  • scale – 均匀分布的上限。

  • column_axis – 包含应正交的列的轴。

  • dtype – 权重的默认数据类型。

返回值

一个 delta 正交初始化器。传递给初始化器的形状必须是 3D、4D 或 5D。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32)  
Array([[[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]],

       [[ 0.27858758, -0.7949833 , -0.53887904],
        [ 0.9120717 ,  0.04322892,  0.40774566],
        [-0.30085585, -0.6050892 ,  0.73712474]],

       [[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
flax.linen.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 Glorot 正态初始化器(也称为 Xavier 正态初始化器)。

一个 Glorot 正态初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 1.0mode="fan_avg",以及 distribution="truncated_normal"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
flax.linen.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 Glorot 均匀初始化器(也称为 Xavier 均匀初始化器)。

一个 Glorot 均匀初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 1.0mode="fan_avg",以及 distribution="uniform"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
flax.linen.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 He 正态初始化器(也称为 Kaiming 正态初始化器)。

一个 He 正态初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 2.0mode="fan_in",以及 distribution="truncated_normal"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
flax.linen.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 He 均匀初始化器(也称为 Kaiming 均匀初始化器)。

一个 He 均匀初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 2.0mode="fan_in",以及 distribution="uniform"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
flax.linen.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 He 正态初始化器(也称为 Kaiming 正态初始化器)。

一个 He 正态初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 2.0mode="fan_in",以及 distribution="truncated_normal"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
flax.linen.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 He 均匀初始化器(也称为 Kaiming 均匀初始化器)。

一个 He 均匀初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 2.0mode="fan_in",以及 distribution="uniform"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
flax.linen.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 Lecun 正态初始化器。

一个 Lecun 正态初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 1.0mode="fan_in",以及 distribution="truncated_normal"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
       [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
flax.linen.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 Lecun 均匀初始化器。

Lecun 均匀初始化器是 jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 1.0mode="fan_in"distribution="uniform"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.56293887,  0.90433645,  0.9119454 ],
       [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
flax.linen.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)#

构建一个初始化器,返回服从正态分布的实数随机数组。

参数
  • stddev – 可选;分布的标准差。

  • dtype – 可选;初始化器的默认数据类型。

返回值

一个初始化器,返回均值为 0,标准差为 stddev 的正态分布数组。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
       [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
flax.linen.initializers.truncated_normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>, lower=-2.0, upper=2.0)#

构建一个初始化器,返回截断正态分布的随机数组。

参数
  • stddev – 可选;未截断分布的标准差。请注意,此函数不应用与 variance scaling 初始化器中相同的标准差校正,用户需要通过 stddev 参数自行应用此校正(如果需要)。

  • dtype – 可选;初始化器的默认数据类型。

  • lower – 表示截断下界的浮点数。在输出乘以标准差之前应用。

  • upper – 表示截断上界的浮点数。在输出乘以标准差之前应用。

返回值

一个初始化器,返回服从截断正态分布的数组,其均值为 0,标准差为 stddev,范围为 \(\rm{lower * stddev} < x < \rm{upper * stddev}\)

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.truncated_normal(5.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 2.9047365,  5.2338114,  5.29852  ],
       [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
flax.linen.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)#

一个初始化器,返回一个全为 1 的常数数组。

忽略 key 参数。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
flax.linen.initializers.ones_init()[source]#

构建一个初始化器,返回一个全为 1 的常数数组。

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
flax.linen.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#

构建一个初始化器,返回均匀分布的正交矩阵。

如果形状不是正方形,则矩阵将具有正交的行或列,具体取决于哪一侧较小。

参数
  • scale – 均匀分布的上限。

  • column_axis – 包含应正交的列的轴。

  • dtype – 权重的默认数据类型。

返回值

一个正交初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],            dtype=float32)
flax.linen.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)#

构建一个初始化器,返回服从均匀分布的实数随机数组。

参数
  • scale – 可选;随机分布的上界。

  • dtype – 可选;初始化器的默认数据类型。

返回值

一个初始化器,返回值在 [0, scale) 范围内均匀分布的数组。

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[7.298188 , 8.691938 , 8.7230015],
       [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

初始化器,根据权重张量的形状调整其尺度。

distribution="truncated_normal"distribution="normal" 时,样本从均值为零、标准差为 \(\sqrt{\frac{scale}{n}}\) 的(截断)正态分布中抽取,其中 n

  • 如果 mode="fan_in",则为权重张量中输入单元的数量;

  • 如果 mode="fan_out",则为输出单元的数量;或

  • 如果 mode="fan_avg",则为输入和输出单元数量的平均值。

此初始化器可以通过 in_axisout_axisbatch_axis 进行配置,以用于一般的卷积或密集层;不在这些参数中的轴被假定为“感受野”(卷积核空间轴)。

distribution="truncated_normal" 时,样本的绝对值在缩放之前将在 2 个标准差处截断。

distribution="uniform" 时,样本从

  • 如果 dtype 是实数,则为均匀区间;或

  • 如果 dtype 是复数,则为均匀圆盘;

中抽取,均值为零,标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 如上所述。

参数
  • scale – 缩放因子(正浮点数)。

  • mode"fan_in""fan_out""fan_avg" 中的一个。

  • distribution – 要使用的随机分布。 "truncated_normal""normal""uniform" 中的一个。

  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

flax.linen.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 Glorot 正态初始化器(也称为 Xavier 正态初始化器)。

一个 Glorot 正态初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 1.0mode="fan_avg",以及 distribution="truncated_normal"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
flax.linen.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#

构建 Glorot 均匀初始化器(也称为 Xavier 均匀初始化器)。

一个 Glorot 均匀初始化器jax.nn.initializers.variance_scaling() 的一个特例,其中 scale = 1.0mode="fan_avg",以及 distribution="uniform"

参数
  • in_axis – 权重数组中输入维度的轴或轴序列。

  • out_axis – 权重数组中输出维度的轴或轴序列。

  • batch_axis – 应忽略的权重数组中的轴或轴序列。

  • dtype – 权重的 dtype。

返回值

一个初始化器。

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
flax.linen.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)#

一个初始化器,返回一个全为 0 的常数数组。

忽略 key 参数。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
flax.linen.initializers.zeros_init()[source]#

构建一个初始化器,返回一个全为 0 的常数数组。

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)