循环#
Flax 的 RNN 模块。
- class flax.nnx.nn.recurrent.LSTMCell(*args, **kwargs)[source]#
LSTM 单元。
该单元的数学定义如下
\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]其中 x 是输入,h 是上一个时间步的输出,c 是记忆。
- __call__(carry, inputs)[source]#
长短期记忆 (LSTM) 单元。
- 参数
carry – LSTM 单元的隐藏状态,使用
LSTMCell.initialize_carry
初始化。inputs – 具有当前时间步输入的 ndarray。除了最后一个维度外的所有维度都被视为批量维度。
- 返回
包含新 carry 和输出的元组。
- initialize_carry(input_shape, rngs=None)[source]#
初始化 RNN 单元的 carry。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 单元的 carry。
- class flax.nnx.nn.recurrent.OptimizedLSTMCell(*args, **kwargs)[source]#
更高效的 LSTM 单元,在 matmul 之前连接状态分量。
这些参数与
LSTMCell
兼容。请注意,只要隐藏大小约为 <= 2048 个单元,此单元通常比LSTMCell
快。该单元的数学定义与
LSTMCell
相同,如下所示\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]其中 x 是输入,h 是上一个时间步的输出,c 是记忆。
- gate_fn#
用于门的激活函数(默认:sigmoid)。
- activation_fn#
用于输出和记忆更新的激活函数(默认:tanh)。
- kernel_init#
用于转换输入的内核的初始化器函数(默认:lecun_normal)。
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化器函数(默认:initializers.orthogonal())。
- bias_init#
用于偏置参数的初始化器(默认:initializers.zeros_init())。
- dtype#
计算的 dtype(默认:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- __call__(carry, inputs)[source]#
优化的长短期记忆 (LSTM) 单元。
- 参数
carry – LSTM 单元的隐藏状态,使用
LSTMCell.initialize_carry
初始化。inputs – 具有当前时间步输入的 ndarray。除了最后一个维度外的所有维度都被视为批量维度。
- 返回
包含新 carry 和输出的元组。
- initialize_carry(input_shape, rngs=None)[source]#
初始化 RNN 单元的 carry。
- 参数
rngs – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 单元的 carry。
- class flax.nnx.nn.recurrent.SimpleCell(*args, **kwargs)[source]#
简单单元。
该单元的数学定义如下
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]其中 x 是输入,h 是上一个时间步的输出。
如果 residual 为 True,
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]- __call__(carry, inputs)[source]#
运行 RNN 单元。
- 参数
carry – RNN 单元的隐藏状态。
inputs – 具有当前时间步输入的 ndarray。除了最后一个维度外的所有维度都被视为批量维度。
- 返回
包含新 carry 和输出的元组。
- initialize_carry(input_shape, rngs=None)[source]#
初始化 RNN 单元的 carry。
- 参数
rng – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 单元的 carry。
- class flax.nnx.nn.recurrent.GRUCell(*args, **kwargs)[source]#
GRU 单元。
该单元的数学定义如下
\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]其中 x 是输入,h 是上一个时间步的输出。
- in_features#
输入特征的数量。
输出特征的数量。
- gate_fn#
用于门的激活函数(默认:sigmoid)。
- activation_fn#
用于输出和记忆更新的激活函数(默认:tanh)。
- kernel_init#
用于转换输入的内核的初始化器函数(默认:lecun_normal)。
- recurrent_kernel_init#
用于转换隐藏状态的内核的初始化器函数(默认:initializers.orthogonal())。
- bias_init#
用于偏置参数的初始化器(默认:initializers.zeros_init())。
- dtype#
计算的数据类型 (默认值: None)。
- param_dtype#
传递给参数初始化器的 dtype(默认:float32)。
- __call__(carry, inputs)[源代码]#
门控循环单元 (GRU) 单元。
- 参数
carry – GRU 单元的隐藏状态,使用
GRUCell.initialize_carry
初始化。inputs – 具有当前时间步输入的 ndarray。除了最后一个维度外的所有维度都被视为批量维度。
- 返回
包含新 carry 和输出的元组。
- initialize_carry(input_shape, rngs=None)[源代码]#
初始化 RNN 单元的 carry。
- 参数
rngs – 传递给 init_fn 的随机数生成器。
input_shape – 提供单元输入形状的元组。
- 返回
给定 RNN 单元的已初始化 carry。
方法
initialize_carry
(input_shape[, rngs])初始化 RNN 单元的 carry。
- class flax.nnx.nn.recurrent.RNN(*args, **kwargs)[源代码]#
RNN
模块接收任何RNNCellBase
实例,并使用flax.nnx.scan()
将其应用于序列。使用
flax.nnx.scan()
。- __call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)[源代码]#
将 self 作为函数调用。
方法
- class flax.nnx.nn.recurrent.Bidirectional(*args, **kwargs)[源代码]#
在两个方向上处理输入并合并结果。
使用示例
>>> from flax import nnx >>> import jax >>> import jax.numpy as jnp >>> # Define forward and backward RNNs >>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) >>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) >>> # Create Bidirectional layer >>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn) >>> # Input data >>> x = jnp.ones((2, 3, 3)) >>> # Apply the layer >>> out = layer(x) >>> print(out.shape) (2, 3, 8)
- __call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[源代码]#
将 self 作为函数调用。
方法
- flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)[源代码]#
沿时间轴翻转输入序列。
此函数可用于为双向 LSTM 的反向准备输入。它解决了以下问题:当天真地翻转存储在矩阵中的多个填充序列时,对于那些被填充的序列,第一个元素将是填充值。此函数将填充保持在末尾,同时翻转其余元素。
示例
>>> from flax.nnx.nn.recurrent import flip_sequences >>> from jax import numpy as jnp >>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) >>> lengths = jnp.array([1, 2, 3]) >>> flip_sequences(inputs, lengths, 1, False) Array([[1, 0, 0], [3, 2, 0], [6, 5, 4]], dtype=int32)
- 参数
inputs – 输入 ID 数组 <int>[batch_size, seq_length]。
lengths – 每个序列的长度 <int>[batch_size]。
- 返回
具有翻转输入的 ndarray。