激活函数#
激活函数。
- class flax.linen.activation.PReLU(param_dtype=<class 'jax.numpy.float32'>, negative_slope_init=0.01, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
参数化线性整流单元 (PReLU) 激活函数。
请注意,PReLU 是一个 Flax 层,而不是一个简单的激活函数,因此它需要在调用之前初始化。
- 示例用法:
>>> import flax.linen as nn
>>> class MLP(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(2)(x) ... x = nn.PReLU()(x) # initialized ... return x
- param_dtype#
传递给参数初始化器的 dtype(默认值:float32)。
- 类型
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- negative_slope_init#
初始化负斜率的值(默认值 0.01)。
- 类型
float
- flax.linen.activation.celu(x, alpha=1.0)[source]#
连续可微指数线性单元激活。
计算元素级函数
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]有关更多信息,请参见 连续可微指数线性单元。
- 参数
x – 输入数组
alpha – 数组或标量(默认值:1.0)
- 返回值
一个数组。
- flax.linen.activation.elu(x, alpha=1.0)[source]#
指数线性单元激活函数。
计算元素级函数
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]- 参数
x – 输入数组
alpha – alpha 值的标量或数组(默认值:1.0)
- 返回值
一个数组。
另请参见
- flax.linen.activation.gelu(x, approximate=True)[source]#
高斯误差线性单元激活函数。
如果
approximate=False
,则计算元素级函数\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]如果
approximate=True
,则使用 GELU 的近似公式\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]有关更多信息,请参见 高斯误差线性单元 (GELUs),第 2 节。
- 参数
x – 输入数组
approximate – 是否使用近似公式或精确公式。
- flax.linen.activation.glu(x, axis=-1)[source]#
门控线性单元激活函数。
计算函数
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]其中数组沿
axis
分成两部分。axis
维度的尺寸必须可被 2 整除。- 参数
x – 输入数组
axis – 应沿其计算分割的轴(默认值:-1)
- 返回值
一个数组。
另请参见
- flax.linen.activation.hard_sigmoid(x)[source]#
硬 Sigmoid 激活函数。
计算元素级函数
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
relu6()
- flax.linen.activation.hard_silu(x)[source]#
硬 SiLU(swish)激活函数
计算元素级函数
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]两者都是
hard_silu()
和hard_swish()
是同一个函数的别名。- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.hard_swish(x)#
硬 SiLU(swish)激活函数
计算元素级函数
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]两者都是
hard_silu()
和hard_swish()
是同一个函数的别名。- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.hard_tanh(x)[source]#
硬 \(\mathrm{tanh}\) 激活函数。
计算元素级函数
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]- 参数
x – 输入数组
- 返回值
一个数组。
- flax.linen.activation.leaky_relu(x, negative_slope=0.01)[source]#
Leaky 线性整流单元激活函数。
计算元素级函数
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]其中 \(\alpha\) =
negative_slope
。- 参数
x – 输入数组
negative_slope – 指定负斜率的数组或标量(默认值:0.01)
- 返回值
一个数组。
另请参见
- flax.linen.activation.log_sigmoid(x)[source]#
对数 sigmoid 激活函数。
计算元素级函数
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.log_softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[source]#
对数 Softmax 函数。
计算
softmax
函数的对数,它将元素重新缩放到范围 \([-\infty, 0)\)。\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]- 参数
x – 输入数组
axis – 计算
log_softmax
的轴或轴。可以是整数或整数元组。where – 包含在
log_softmax
中的元素。
- 返回值
一个数组。
注意
如果任何输入值为
+inf
,则结果将全部为NaN
:这反映了这样一个事实,即在浮点数学中,inf / inf
没有明确的定义。另请参见
- flax.linen.activation.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None)[source]#
对数求和指数约简。
JAX 实现的
scipy.special.logsumexp()
。\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]其中 \(j\) 索引跨越一个或多个要约简的维度。
- 参数
a – 输入数组
axis – 要约简的轴或轴。可以是
None
、整数或整数元组。b – \(\mathrm{exp}(a)\) 的缩放因子。必须可广播到 a 的形状。
keepdims – 如果为
True
,则约简的轴作为大小为 1 的维度保留在输出中。return_sign – 如果为
True
,则输出将是一个(result, sign)
对,其中sign
是总和的符号,而result
包含其绝对值的对数。如果为False
,则仅返回result
,如果总和为负数,它将包含 NaN 值。where – 包含在约简中的元素。
- 返回值
可以是数组
result
或数组对(result, sign)
,具体取决于return_sign
参数的值。
- flax.linen.activation.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[source]#
对给定的索引进行独热编码。
输入
x
中的每个索引都被编码为一个长度为num_classes
的零向量,其中位于index
处的元素被设置为 1>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
位于范围 [0, num_classes) 之外的索引将被编码为零
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
- 参数
x – 索引张量。
num_classes – 独热维度中的类数。
dtype – 可选,返回值的浮点类型(默认值为
jnp.float_
)。axis – 计算函数的轴或轴。
- flax.linen.activation.relu(x)[source]#
修正线性单元激活函数。
计算元素级函数
\[\mathrm{relu}(x) = \max(x, 0)\]但在微分过程中,我们采用
\[\nabla \mathrm{relu}(0) = 0\]有关更多信息,请参阅 ReLU’(0) 对反向传播的数值影响。
- 参数
x – 输入数组
- 返回值
一个数组。
示例
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
另请参见
relu6()
- flax.linen.activation.selu(x)[source]#
缩放指数线性单元激活。
计算元素级函数
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]其中 \(\lambda = 1.0507009873554804934193349852946\) 和 \(\alpha = 1.6732632423543772848170429916717\).
有关更多信息,请参阅 自归一化神经网络。
- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.sigmoid(x)[source]#
Sigmoid 激活函数。
计算元素级函数
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.silu(x)[source]#
SiLU(又名 swish)激活函数。
计算元素级函数
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.soft_sign(x)[source]#
Soft-sign 激活函数。
计算元素级函数
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]- 参数
x – 输入数组
- flax.linen.activation.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[source]#
Softmax 函数。
计算将元素重新缩放到范围 \([0, 1]\) 的函数,使得沿
axis
的元素之和为 \(1\)。\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]- 参数
x – 输入数组
axis – 计算 softmax 的轴或轴。跨这些维度求和的 softmax 输出应之和为 \(1\)。可以是整数或整数元组。
where – 包含在
softmax
中的元素。
- 返回值
一个数组。
注意
如果任何输入值为
+inf
,则结果将全部为NaN
:这反映了这样一个事实,即在浮点数学中,inf / inf
没有明确的定义。另请参见
- flax.linen.activation.softplus(x)[source]#
Softplus 激活函数。
计算元素级函数
\[\mathrm{softplus}(x) = \log(1 + e^x)\]- 参数
x – 输入数组
- flax.linen.activation.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[source]#
通过减去
mean
并除以 \(\sqrt{\mathrm{variance}}\) 来标准化数组。
- flax.linen.activation.swish(x)#
SiLU(又名 swish)激活函数。
计算元素级函数
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]- 参数
x – 输入数组
- 返回值
一个数组。
另请参见
- flax.linen.activation.tanh(x, /)[source]#
按元素计算双曲正切。
LAX 后端实现
numpy.tanh()
。原始文档字符串如下。
等效于
np.sinh(x)/np.cosh(x)
或-1j * np.tan(1j*x)
。- 参数
x (array_like) – 输入数组。
- 返回值
y – 对应的双曲正切值。如果 x 是标量,则这是一个标量。
- 返回类型
ndarray
参考文献
- 1
M. Abramowitz 和 I. A. Stegun,《数学函数手册》。纽约州纽约:Dover,1972 年,第 83 页。 https://personal.math.ubc.ca/~cbm/aands/page_83.htm
- 2
维基百科,“双曲函数”,https://en.wikipedia.org/wiki/Hyperbolic_function