变换#

模块上的JAX变换。

Jax 函数式变换作用于纯函数。Flax 将这些变换扩展到也作用于具有状态变量和 PRNG 序列的模块。我们将这些扩展版本称为“提升的变换”。

提升的变换可以应用于Module类或以Module实例作为其第一个参数的函数。

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)[source]#

jax.vmap的提升版本。

有关 Jax 中未提升的批量转换,请参阅jax.vmap

vmap可用于向Module添加批处理轴。例如,我们可以创建Dense的具有批处理轴的版本,该版本不共享参数

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})

通过使用variable_axes={'params': 0},我们指示参数本身在映射轴上进行映射,因此未在映射轴上共享。因此,我们还拆分了“params”RNG,否则参数将在映射轴上初始化为相同。

类似地,vmap可用于添加具有参数共享的批处理轴

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})

这里我们使用variable_axes={'params': None}来指示参数变量在映射轴上共享。因此,“params”RNG也必须共享。

参数
  • targetModule或以Module作为其第一个参数的函数。

  • variable_axes – 被提升到批量转换中的变量集合。使用None表示广播集合,或使用整数在轴上进行映射。例如,传递variable_axes={'params': None}将表示参数变量应沿映射轴共享。

  • split_rngs – 拆分的 PRNG 序列对于批处理维度中的每个索引将不同。未拆分的 PRNG 将被广播。

  • in_axes – 指定输入参数的映射(请参阅jax.vmap)。

  • out_axes – 指定返回值的映射(请参阅jax.vmap)。

  • axis_size – 指定批处理轴的大小。仅当无法从输入参数中推导出它时,才需要指定它。

  • axis_name – 为批处理轴指定名称。可与并行约简原语(例如jax.lax.pmeanjax.lax.ppermute等)一起使用。请注意,这仅用于 pmap 和 shard map。对于 SPMD jit,您无需手动同步。只需确保正确注释了轴,XLA:SPMD 将插入必要的集体操作。

  • methods – 如果targetModule,则要对其进行 vmap 处理的Module的方法。

  • spmd_axis_name – 添加到fn中出现的任何 pjit 分片约束的轴名称。另请参阅google/flax

  • metadata_params – 传递到变量树中 AxisMetadata 实例的参数字典。

返回值

target的批量/矢量化版本,具有相同的参数,但在由in_axes指示的位置具有额外的轴,以及相同的返回值,但在由out_axes指示的位置具有额外的轴。

flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None, _split_transpose=False)[source]#

jax.lax.scan的提升版本。

有关 Jax 中未提升的扫描,请参阅jax.lax.scan

为了提高与vmap的一致性,此版本的扫描使用in_axesout_axes来确定哪些参数被扫描以及沿哪个轴扫描。

scan区分循环内部 3 种不同类型的值

  1. scan:在循环中迭代的值。所有扫描值在它们被扫描的轴上的大小必须相同。扫描输出将沿扫描轴堆叠。

  2. carry:携带值在每次循环迭代时更新。在整个循环中,它必须具有相同的形状和数据类型。

  3. broadcast:循环闭包的值。当变量被广播时,它们通常在循环体内部初始化,但独立于循环变量。

target应具有签名(module, carry, *xs) -> (carry, ys),其中xsys是进出循环的扫描值。

示例

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape =  x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)

请注意,当向 nn.scan 提供函数时,扫描操作会从第三个参数开始,在所有参数上进行,如 in_axes 所指定。前面的示例也可以使用函数式形式编写为

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape =  x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))

您还可以使用 scan 来减少 JAX 程序的编译时间,方法是将多个层合并到单个扫描循环中,当您有一系列相同的层需要迭代应用于输入时,您可以这样做。例如

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))

为了减少编译和内存使用,您可以使用 remat_scan(),它除了会检查点扫描循环中的每一层。

参数
  • targetModule或以Module作为其第一个参数的函数。

  • variable_axes – 扫描的变量集合。

  • variable_broadcast – 指定广播的变量集合。广播变量不应该依赖于任何无法从循环中提升的计算。这通常用于在 fn 内部定义共享参数。

  • variable_carry – 指定贯穿循环的变量集合。对这些变量的修改将传递到下一个迭代,并在扫描完成时保留。

  • split_rngs – 分割的 PRNG 序列在每次循环迭代中将不同。如果 split 为 False,则 PRNG 在迭代之间将相同。

  • in_axes – 指定要扫描的参数轴。应该是参数的前缀树。使用 flax.core.broadcast 将整个输入馈送到扫描主体每次迭代。

  • out_axes – 指定要扫描的返回值轴。应该是返回值的前缀树。

  • length – 指定循环迭代次数。只有在无法从扫描参数中推导出时,才需要指定此参数。

  • reverse – 如果为真,则从结束到开始反向扫描。

  • unroll – 在循环的单个迭代中展开多少次扫描迭代(默认值:1)。

  • data_transform – 可选函数,用于转换提升的扫描 body_fn 内部的原始功能核心变量和 rng 组,用于内联 SPMD 注释。

  • metadata_params – 传递到变量树中 AxisMetadata 实例的参数字典。

  • methods – 如果 target 是一个 Module,则要扫描的 Module 的方法。

  • _split_transpose – 一种实验性功能,用于将扫描的转置拆分为扫描和映射,由实验性的 Jax lax.scan() 功能支持。

返回值

具有签名 (module, carry, *xs) -> (carry, ys) 的扫描函数,其中 xsys 是进出循环的扫描值。

flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), static_argnames=(), donate_argnums=(), device=None, backend=None, methods=None)[source]#

jax.jit 的提升版本。

参数
  • targetModule或以Module作为其第一个参数的函数。

  • variables – 提升的变量集合。默认情况下,所有集合都将被提升。

  • rngs – 提升的 PRNG 序列。默认情况下,所有 PRNG 序列都将被提升。

  • static_argnums – 一个整数或整数集合,指定哪些位置参数被视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应参数的值可以是任何 Python 对象。静态参数应该是可散列的,这意味着实现了 __hash____eq__,并且是不可变的。使用这些常量的不同值调用 jitted 函数将触发重新编译。如果 jitted 函数调用的位置参数少于 static_argnums 指示的参数,则会引发错误。不是数组或其容器的参数必须标记为静态。默认为 ()。

  • static_argnames – 一个可选的字符串或字符串集合,指定哪些命名参数被视为静态(编译时常量)。有关详细信息,请参阅有关 static_argnums 的注释。如果未提供但设置了 static_argnums,则默认为基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • donate_argnums – 指定哪些参数被“捐赠”给计算。如果您在计算完成后不再需要这些参数,则可以安全地捐赠参数。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应该重用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将引发错误。

  • device – 这是一个实验性功能,API 可能会更改。可选,jitted 函数将在其上运行的设备。(可用的设备可以通过 jax.devices() 获取。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend – 表示 XLA 后端的字符串:'cpu''gpu''tpu'

  • methods – 如果 target 是一个 Module,则要 jit 的 Module 的方法。

返回值

目标的包装版本,设置为即时编译。

flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)#

jax.checkpoint 的提升版本。

检查点是一种通过在反向传播期间重新计算激活来减少内存使用量的技术。在训练大型模型时,检查点模型的部分内容以权衡内存使用量和额外计算量可能会有所帮助。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.checkpoint
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(128)(x)
...     x = nn.relu(x)
...     x = nn.Dense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))

此函数与 remat 具有相同的别名,就像 jax.remat 一样。

参数
  • target – 一个 Module 或一个以 Module 作为其第一个参数的函数。在计算目标的梯度时,将重新计算中间计算。

  • variables – 提升的变量集合。默认情况下,所有集合都将被提升。

  • rngs – 提升的 PRNG 序列。默认情况下,所有 PRNG 序列都将被提升。

  • concrete – 可选,布尔值,指示 fun 是否可能涉及依赖于值的 Python 控制流(默认值为 False)。对这种控制流的支持是可选的,并且默认情况下是禁用的,因为在与 jax.jit() 的某些边缘情况组合中,它会导致一些额外的计算。

  • prevent_cse – 可选,布尔值,指示是否阻止从微分生成的 HLO 中的通用子表达式消除 (CSE) 优化。此 CSE 预防是有成本的,因为它会破坏其他优化,并且因为它可能会在某些后端(尤其是 GPU)上产生高开销。默认值为 True,因为否则,在 jitpmap 下,CSE 会破坏此装饰器的目的。但在某些情况下,例如在 scan 内部使用时,此 CSE 预防机制是不必要的,在这种情况下,应将 prevent_cse 设置为 False。

  • static_argnums – 可选,整数或整数序列,指示要为跟踪和缓存目的专门化哪些参数值。将参数指定为静态可以在跟踪时避免 ConcretizationTypeErrors,但代价是更多的重新跟踪开销。

  • policy – 实验性检查点策略,请参阅 jax.checkpoint

  • methods – 一个可选的方法名称列表,这些方法将被提升,如果 methods 为 None(默认值),则只有 __call__ 方法将被提升。如果``target`` 是一个函数,则忽略 methods

返回值

target 的包装版本。在计算梯度时,将在反向传递中重新计算中间计算。

flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))[source]#

结合 remat 和 scan 以提高内存效率并实现常数时间编译。

remat_scan 允许以模型深度为参考,实现常数编译时间和亚线性内存使用。在付出少量常数开销的情况下。这通常有利于非常深的模型。

示例

>>> import flax.linen as nn

>>> class BigModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
...     # 100x dense with O(sqrt(N)) memory for gradient computation
...     return DenseStack(8, name="dense_stack")(x)
参数
  • targetModule或以Module作为其第一个参数的函数。

  • **lengths** – 给定级别上的循环迭代次数。总迭代次数 n = prod(lengths)。每个循环都会重新计算。这样,内存消耗与 n^(1 / d) 成正比,其中 d = len(lengths)。最小的内存消耗需要调整长度,以便在嵌套循环的每个级别消耗相同的内存量。

  • policy – 实验性检查点策略,请参阅 jax.checkpoint

  • variable_broadcast – 指定广播的变量集合。广播变量不应该依赖于任何无法从循环中提升的计算。这通常用于在 fn 内部定义共享参数。

  • variable_carry – 指定贯穿循环的变量集合。对这些变量的修改将传递到下一个迭代,并在扫描完成时保留。

  • **variable_axes** – 要扫描的变量集合。默认为 {True: 0}

  • **split_rngs** – 分割的 PRNG 序列对于每个循环迭代将是不同的。如果 split 为 False,则 PRNG 在迭代之间将相同。默认为 {True: True}

返回值

target 的包装版本,它会重复自身 prod(lengths) 次。

flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)[source]#

映射模块内的变量。

map_variables 可用于在应用模块之前和之后转换模块内部的变量。这对于在不修改模块本身的情况下屏蔽模块的权重等用途很有帮助。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
参数
  • **target** – 要转换的模块或函数。

  • **mapped_collections** – 要转换的集合。

  • **trans_in_fn** – 在应用模块或函数之前修改变量。

  • **trans_out_fn** – 在应用模块或函数之后修改变量,仅当 initmutable 不为 False 时才应用。

  • **init** – 如果为 True,则在转换之前初始化变量。

  • **mutable** – 如果为 True,则映射的变量集合将是可变的。

  • **rngs** – 添加到转换作用域的 PRNGSequences(默认:全部)。

  • **variables** – 添加到转换作用域的其他变量集合。除了由 target 指定的那些(默认:全部)。

  • **methods** – 如果 target 是一个 Module,则为 Module 的方法来映射变量。

返回值

target 的包装版本,它将映射指定的集合。

flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)[source]#

jax.jvp 的提升版本。

有关未提升的雅可比-向量积(前向梯度),请参阅 jax.jvp

请注意,不会为变量返回任何切线。当需要变量切线时,应由 fn 使用 Module.variables 显式返回其值。

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     p = self.param('test', nn.initializers._init(), ())
...     return p * x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     scale = LearnScale()
...     vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                     scale.variables.get('params', {}))
...     _, out_t = nn.jvp(
...         lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
...         variable_tangents={'params': vars_t})
...     return out_t

示例

>>> def learn_scale(scope, x):
...   p = scope.param('scale', nn.initializers.zeros_init(), ())
...   return p * x

>>> def f(scope, x):
...   vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {}))
...   x, out_t = lift.jvp(
...       learn_scale, scope, (x,), (jnp.zeros_like(x),),
...       variable_tangents={'params': vars_t})
...   return out_t
参数
  • **fn** – 要微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应返回数组、标量或数组或标量的标准 Python 容器。它将接收作用域和原始值作为参数。

  • **mdl** – 将对其变量进行微分的模块。

  • **primals** – 应在其中评估 fun 的雅可比矩阵的原始值。应为参数的元组或列表,其长度应等于 fun 的位置参数的数量。

  • **tangents** – 应在其中评估雅可比-向量积的切线向量。应为切线的元组或列表,与 primals 具有相同的树结构和数组形状。

  • **variable_tangents** – 与作用域具有相同结构的字典或 PyTree 字典。字典中的每个条目都指定变量集合的切线。在 variable_tangents 中未指定集合等效于将零向量作为切线传递。

  • **variables** – 在 fn 中可用但未接收切线的其他变量集合。

  • **rngs** – 在 fn 内部可用的 prngs。

返回值

一个 (primals_out, tangents_out) 对,其中 primals_outfun(*primals),而 tangents_out 是在 primals 处使用 tangents 评估的 function 的雅可比-向量积。tangents_out 值与 primals_out 具有相同的 Python 树结构和形状。

flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True, multi_scope=False)[source]#

jax.vjp 的提升版本。

有关未提升的向量-雅可比积(反向梯度),请参阅 jax.vjp

请注意,将为 vjp_variables 指定的集合中的所有变量返回梯度。但是,反向函数仅期望 fn 的返回值的余切线。如果变量也需要余切线,则可以使用 Module.variablesfn 返回它们。

示例

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     p = self.param('scale', nn.initializers.zeros_init(), ())
...     return p * x * y

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
...     params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
...     return z, params_grad, x_grad, y_grad
参数
  • **fn** – 要微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应返回数组、标量或数组或标量的标准 Python 容器。它将接收作用域和原始值作为参数。

  • **mdl** – 将对其变量进行微分的模块。

  • **\*primals** – 应在其中评估 fn 的雅可比矩阵的原始值序列。primals 的长度应等于 fn 的位置参数的数量。每个原始值应为数组、标量或其标准 Python 容器的元组。

  • **has_aux** – 可选,布尔值。指示 fn 是否返回一个对,其中第一个元素被视为要微分的数学函数的输出,第二个元素是辅助数据。默认为 False

  • reduce_axes – 可选,轴名称的元组。如果某个轴在此处列出,并且fn隐式地在此轴上广播了一个值,则反向传播将对相应的梯度执行psum操作。否则,VJP 将在命名轴上按示例进行。例如,如果'batch'是命名的批处理轴,则vjp(f, *args, reduce_axes=('batch',))将创建一个在批处理上求和的 VJP 函数,而vjp(f, *args)将创建一个按示例的 VJP。

  • vjp_variables – vjpfun 将为该过滤器指定的所有变量集合返回余切向量。

  • variables – 其他在fn内部可用但不会接收余切的变量集合。

  • **rngs** – 在 fn 内部可用的 prngs。

  • multi_scope – 对于包含来自外部模块传递的多个作用域的模块,允许返回多个作用域的变量梯度,而不是报错。

返回值

如果has_auxFalse,则返回一个(primals_out, vjpfun)对,其中primals_outfn(*primals)vjpfun是一个函数,它将一个与primals_out形状相同的余切向量映射到一个与primals形状相同的余切向量元组,表示在primals处计算的fn的向量-雅可比乘积。如果has_auxTrue,则返回一个(primals_out, vjpfun, aux)元组,其中auxfn返回的辅助数据。

flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())[source]#

jax.custom_vjp的提升版本。

forward_fnbackward_fn共同定义了fn的自定义vjp。如果未计算vjp(反向梯度),则原始fn将运行。

forward_fn接收与fn相同的参数,但预期返回一个元组,其中包含fn(mdl, *args)的输出和传递给backward_fn的残差。

backward_fn接收非微分参数、残差和输出切线。它应该返回一个元组,其中包含变量和输入切线。

请注意,nn.vjp返回的vjp函数可以作为残差传递并用于backward_fn中。在反向传播过程中,作用域不可用。如果在backward_fn中需要模块,则可以获取变量的快照并将其作为残差返回到forward_fn中。

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def f(mdl, x):
...       return mdl(x)
...
...     def fwd(mdl, x):
...       return nn.vjp(f, mdl, x)
...
...     def bwd(vjp_fn, y_t):
...       params_t, *inputs_t = vjp_fn(y_t)
...       params_t = jax.tree_util.tree_map(jnp.sign, params_t)
...       return (params_t, *inputs_t)
...
...     sign_grad = nn.custom_vjp(
...         f, forward_fn=fwd, backward_fn=bwd)
...     return sign_grad(nn.Dense(1), x).reshape(())

>>> x = jnp.ones((2,))
>>> variables = Foo().init(jax.random.key(0), x)
>>> grad = jax.grad(Foo().apply)(variables, x)
参数
  • fn – 要为其定义custom_vjp的函数。

  • forward_fn – 一个与fn具有相同参数的函数,返回一个包含原始输出和将传递给backward_fn的残差的元组。

  • backward_fn – 参数作为(*nondiff_args, residuals, tangents)传递。该函数应返回一个元组,其中包含由grad_vars指定的集合中变量的切线和输入参数(模块和非微分参数除外)。

  • grad_vars – 将为其计算vjp的集合(默认值:“params”)。

  • nondiff_argnums – 不计算vjp的参数。

返回值

一个与fn具有相同签名的函数,并带有自定义vjp。

flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))[source]#

jax.lax.while_loop的提升版本。

提升的作用域被传递给cond_fnbody_fn。广播的变量是不可变的。携带变量是可变的,但不能更改形状和数据类型。这也意味着您不能在循环体内部初始化变量。如果需要变量初始化,请考虑在调用while_loop之前手动调用一次body_fn

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class WhileLoopExample(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def cond_fn(mdl, c):
...       return mdl.variables['state']['acc'] < 10
...     def body_fn(mdl, c):
...       acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
...       acc.value += 1
...       y = nn.Dense(c.shape[-1])(c)
...       return y
...     c = x
...     if self.is_mutable_collection('params'):
...       return body_fn(self, c)
...     else:
...       return nn.while_loop(cond_fn, body_fn, self, c,
...                             carry_variables='state')

>>> k = jax.random.key(0)
>>> x = jnp.ones((2, 2))
>>> initial_vars = WhileLoopExample().init(k, x)
>>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])
参数
  • cond_fn – 只要循环应该继续,就应该返回True。

  • body_fn – while循环的主体。

  • mdl – 应该被提升到循环中的模块。

  • init – 传递给循环的初始状态

  • carry_variables – 贯穿循环的集合,因此是可变的(默认值:无)。

  • broadcast_variables – 被闭包覆盖的集合,因此是只读的(默认值:所有集合)

  • split_rngs – 分割的 PRNG 序列在每次循环迭代中将不同。如果 split 为 False,则 PRNG 在迭代之间将相同。

返回值

执行while循环后最终的状态。

flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)[source]#

jax.lax.cond的提升版本。

true_funfalse_fun返回的值必须具有相同的Pytree结构、形状和数据类型。在分支内部创建或更新的变量也必须具有相同的结构。请注意,当仅在一个分支中创建变量或子模块时,此约束会被违反。因为仅在一个分支中初始化变量会导致参数结构不同。

示例

>>> import flax.linen as nn

>>> class CondExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, pred):
...     self.variable('state', 'true_count', lambda: 0)
...     self.variable('state', 'false_count', lambda: 0)
...     def true_fn(mdl, x):
...       mdl.variable('state', 'true_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def false_fn(mdl, x):
...       mdl.variable('state', 'false_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     return nn.cond(pred, true_fn, false_fn, self, x)
参数
  • pred – 确定评估true_fun还是false_fun。

  • true_fun – 当predTrue时评估的函数。签名为(module, *operands) -> T。

  • false_fun – 当predFalse时评估的函数。签名为(module, *operands) -> T。

  • mdl – 要传递的模块。

  • *operands – 传递给true_funfalse_fun的参数

  • variables – 传递给条件分支的变量集合(默认值:全部)

  • rngs – 传递给条件语句的PRNG序列(默认值:全部)

返回值

已评估分支(true_funfalse_fun)的结果。

flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)[source]#

jax.lax.switch的提升版本。

来自 branches 的返回值必须具有相同的 Pytree 结构、形状和数据类型。在分支内部创建或更新的变量也必须具有相同的结构。请注意,当仅在一个分支中创建变量或子模块时,会违反此约束。因为仅在一个分支中初始化变量会导致参数结构不同。

示例

>>> import flax.linen as nn

>>> class SwitchExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, index):
...     self.variable('state', 'a_count', lambda: 0)
...     self.variable('state', 'b_count', lambda: 0)
...     self.variable('state', 'c_count', lambda: 0)
...     def a_fn(mdl, x):
...       mdl.variable('state', 'a_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def b_fn(mdl, x):
...       mdl.variable('state', 'b_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     def c_fn(mdl, x):
...       mdl.variable('state', 'c_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     return nn.switch(index, [a_fn, b_fn, c_fn], self, x)

如果希望每个分支具有不同的参数结构,则应在调用 switch 之前在初始化时运行所有分支。

>>> class MultiHeadSwitchExample(nn.Module):
...   def setup(self) -> None:
...     self.heads = [
...       nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]),
...       nn.Sequential([nn.Dense(11), nn.Dense(5)]),
...       nn.Dense(5),
...     ]
...
...   @nn.compact
...   def __call__(self, x, index):
...     def head_fn(i):
...       return lambda mdl, x: mdl.heads[i](x)
...     branches = [head_fn(i) for i in range(len(self.heads))]
...
...     # run all branches on init
...     if self.is_mutable_collection('params'):
...       for branch in branches:
...         _ = branch(self, x)
...
...     return nn.switch(index, branches, self, x)
参数
  • index – 整数标量类型,指示要应用哪个分支函数。

  • branches – 基于 index 应用的一系列函数。每个函数的签名为 (module, *operands) -> T。

  • mdl – 要传递的模块。

  • *operands – 传递给分支的参数。

  • variables – 传递给条件分支的变量集合(默认值:全部)

  • rngs – 传递给条件语句的PRNG序列(默认值:全部)

返回值

已评估分支的结果。