转换#

class flax.nnx.Jit(*args, **kwargs)[源代码]#
class flax.nnx.Remat(*args, **kwargs)[源代码]#
class flax.nnx.Scan(*args, **kwargs)[源代码]#
class flax.nnx.Vmap(*args, **kwargs)[源代码]#
flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#

jax.grad 的提升版本,可以处理模块/图节点作为参数。

每个图节点的可微分状态由 wrt 过滤器定义,默认设置为 nnx.Param。 在内部,提取图节点的 State,根据 wrt 过滤器进行过滤,并传递到底层的 jax.grad 函数。图节点的梯度类型为 State

示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
参数
  • fun – 要微分的函数。其参数在 argnums 指定的位置应该是数组、标量、图节点或标准 Python 容器。在 argnums 指定的位置的参数数组必须是不精确的(即浮点型或复数型)。它应该返回一个标量(包括形状为 () 的数组,但不包括形状为 (1,) 等的数组)。

  • argnums – 可选,整数或整数序列。指定要对其求导的(一个或多个)位置参数(默认为 0)。

  • has_aux – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被认为是要求微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic – 可选,布尔值。指示 fun 是否承诺是全纯的。如果为 True,则输入和输出必须是复数。默认为 False。

  • allow_int – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有一个平凡的向量空间 dtype (float0)。默认为 False。

  • reduce_axes – 可选,轴名称元组。如果此处列出了一个轴,并且 fun 隐式地在该轴上广播一个值,则反向传播将执行相应梯度的 psum。否则,梯度将是每个示例在命名轴上的梯度。例如,如果 'batch' 是命名的批处理轴,grad(f, reduce_axes=('batch',)) 将创建一个计算总梯度的函数,而 grad(f) 将创建一个计算每个示例梯度的函数。

flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[源代码]#

jax.jit 的提升版本,可以处理模块/图节点作为参数。

参数
  • fun

    要进行 JIT 编译的函数。fun 应该是一个纯函数,因为副作用可能只执行一次。

    fun 的参数和返回值应该是数组、标量或其(嵌套)标准 Python 容器(元组/列表/字典)。由 static_argnums 指示的位置参数可以是任何东西,只要它们是可散列的并且定义了相等运算。静态参数包含在编译缓存键中,这就是为什么必须定义散列和相等运算符的原因。

    JAX 保留对 fun 的弱引用,以用作编译缓存键,因此对象 fun 必须是可弱引用的。大多数 Callable 对象已经满足此要求。

  • in_shardings

    fun 的参数结构匹配的 Pytree,所有实际参数都替换为资源分配规范。指定 pytree 前缀(例如,用一个值代替整个子树)也是有效的,在这种情况下,叶子会广播到该子树中的所有值。

    in_shardings 参数是可选的。JAX 将从输入 jax.Array 推断分片,如果无法推断分片,则默认为复制输入。

    有效资源分配规范是
    • Sharding,它将决定如何对值进行分区。

      有了这个,就不需要使用网格上下文管理器。

    • None,将使 JAX 可以自由选择它想要的任何分片。对于 in_shardings,JAX 将其标记为复制,但此行为将来可能会更改。对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

    每个维度的大小必须是分配给它的资源总数的倍数。这类似于 pjit 的 in_shardings。

  • out_shardings

    类似于 in_shardings,但指定函数输出的资源分配。这类似于 pjit 的 out_shardings。

    out_shardings 参数是可选的。如果未指定,jax.jit() 将使用 GSPMD 的分片传播来确定输出的分片应该是什么。

  • static_argnums

    一个可选的整数或整数集合,用于指定将哪些位置参数视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在追踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。

    静态参数应该是可哈希的,这意味着 __hash____eq__ 都已实现,并且是不可变的。使用这些常量的不同值调用 jitted 函数将触发重新编译。不是数组或其容器的参数必须标记为静态。

    如果既没有提供 static_argnums 也没有提供 static_argnames,则不将任何参数视为静态。如果未提供 static_argnums 但提供了 static_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找与 static_argnames 相对应的任何位置参数(反之亦然)。如果同时提供了 static_argnumsstatic_argnames,则不使用 inspect.signature,并且仅将 static_argnumsstatic_argnames 中列出的实际参数视为静态。

  • static_argnames – 一个可选的字符串或字符串集合,用于指定将哪些命名参数视为静态(编译时常量)。有关详细信息,请参阅有关 static_argnums 的注释。如果未提供,但设置了 static_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • donate_argnums

    指定哪些位置参数缓冲区“捐赠”给计算。如果您在计算完成后不再需要参数缓冲区,则可以安全地捐赠它们。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重复使用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将引发错误。默认情况下,不捐赠任何参数缓冲区。

    如果既没有提供 donate_argnums 也没有提供 donate_argnames,则不捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找与 donate_argnames 相对应的任何位置参数(反之亦然)。如果同时提供了 donate_argnumsdonate_argnames,则不使用 inspect.signature,并且仅将 donate_argnumsdonate_argnames 中列出的实际参数捐赠出去。

    有关缓冲区捐赠的更多详细信息,请参阅FAQ

  • donate_argnames – 一个可选的字符串或字符串集合,用于指定将哪些命名参数捐赠给计算。有关详细信息,请参阅有关 donate_argnums 的注释。如果未提供,但设置了 donate_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • keep_unused – 如果 False(默认值),则 JAX 确定 fun 未使用的参数可能会从生成的已编译 XLA 可执行文件中删除。这些参数不会传输到设备或提供给底层可执行文件。如果 True,则不会修剪未使用的参数。

  • device – 这是一个实验性功能,API 可能会发生变化。可选,jitted 函数将在其上运行的设备。(可以通过 jax.devices() 检索可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串:'cpu''gpu''tpu'

  • inline – 指定此函数是否应内联到封闭的 jaxprs 中(而不是表示为具有其自身 subjaxpr 的 xla_call 原语的应用程序)。默认为 False。

返回

为即时编译设置的 fun 的包装版本。

flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[source]#
flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[source]#
flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[source]#

jax.vmap 的引用感知版本。

参数
  • f – 要在附加轴上映射的函数。

  • in_axes – 一个整数、None 或值序列,用于指定要映射的输入数组轴(请参阅jax.vmap)。除了整数和 None 之外,StateAxes 可以用于控制图节点(如模块)的矢量化方式,方法是指定应用于图节点子状态的轴,给定一个过滤器

  • out_axes – 一个整数、None 或 pytree,指示映射轴应出现在输出中的位置(请参阅jax.vmap)。

  • axis_name – 可选,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合。

  • axis_size – 可选,一个整数,指示要映射的轴的大小。如果未提供,则从参数推断映射的轴大小。

返回

带有参数的 f 的批处理/矢量化版本,这些参数与 f 的参数相对应,但在 in_axes 指示的位置有额外的数组轴,并且返回值与 f 的返回值相对应,但在 out_axes 指示的位置有额外的数组轴。

示例

>>> from flax import nnx
>>> from jax import random, numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
...
>>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
... def forward(model, x):
...   return model(x)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)
>>> class LinearEnsemble(nnx.Module):
...   def __init__(self, num, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
...
>>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2,))
...
>>> @nnx.vmap(in_axes=(0, None), out_axes=0)
... def forward(model, x):
...   return jnp.dot(x, model.w.value)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)

要控制图节点子状态的矢量化方式,可以将 StateAxes 传递给 in_axesout_axes,指定应用于给定过滤器的每个子状态的轴。以下示例显示了如何在保持不同批次统计数据和 dropout 随机状态的同时在集成成员之间共享参数

>>> class Foo(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.arange(4))
...     self.b = nnx.BatchStat(jnp.arange(4))
...
>>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
... def mul(foo):
...   return foo.a * foo.b
...
>>> foo = Foo()
>>> y = mul(foo)
>>> y
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
flax.nnx.eval_shape(f, *args, **kwargs)[source]#
flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[来源]#

jax.custom_vjp 的引用感知版本。

nnx.custom_vjp 接受模块和其他 Flax NNX 对象作为参数。与 JAX 版本的主要区别在于,由于模块遵循引用语义,它们会将输入的 State 更新作为辅助输出传播。这意味着 bwd 函数中的传入梯度将具有 (input_updates_g, out_g) 的形式,其中 input_updates_g 是输入相对于输入的梯度更新状态。输入上的所有模块项都将在 input_updates_g 中有一个关联的 State 项,而所有非模块项将显示为 None。tanget 的形状应与输入的形状相同,其中 State 项替换相应的模块项。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
...
>>> class Foo(nnx.Module):
...   def __init__(self, x, y):
...     self.x = nnx.Param(x)
...     self.y = nnx.Param(y)
...
>>> @nnx.custom_vjp
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, sin_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g['x'].value = cos_x * out_g * m.y
...   m_g['y'].value = sin_x * out_g
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grads = nnx.grad(f)(m)
...
>>> jax.tree.map(jnp.shape, grads)
State({
  'x': VariableState(
    type=Param,
    value=()
  ),
  'y': VariableState(
    type=Param,
    value=()
  )
})

请注意,表示 input_updates_g 上模块项的 State 对象与输出 tanget 中预期的 State 对象具有相同的形状。这意味着您通常可以直接从 input_updates_g 复制它们,并使用其对应的梯度值更新它们。

您可以通过将 DiffState 传递给 nondiff_argnums 来选择模块和其他图节点的可微分子状态(具有切线)。例如,如果您只想微分 Foo 类的 x 属性,您可以执行以下操作

>>> x_attribute = nnx.PathContains('x')
>>> diff_state = nnx.DiffState(0, x_attribute)
...
>>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y  # type: ignore

>>> def f_fwd(m: Foo):
...   y = f(m)
...   res = (jnp.cos(m.x), m)  # type: ignore
...   return y, res
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g.x.value = cos_x * out_g * m.y
...   del m_g['y'] # y is not differentiable
...   return (m_g,)

>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
...
>>> jax.tree.map(jnp.shape, grad)
State({
  'x': VariableState(
    type=Param,
    value=()
  )
})

请注意,grad 无法计算没有由 custom_vjp 定义的切线的状态的梯度,在上面的示例中,我们重用相同的 x_attribute 过滤器来保持 custom_vjpgrad 的同步。

参数
  • fun – 可调用的基本函数。

  • nondiff_argnums – 指定不微分的参数索引的整数或 DiffState 对象元组。默认情况下,所有参数都被微分。整数不能用于将模块等图节点标记为不可微分,在这种情况下,请使用 DiffState 对象。DiffState 对象定义可微分的子状态集,与此参数名称所暗示的相反,这样做是为了与 grad 兼容。

flax.nnx.cond(pred, true_fun, false_fun, *operands, **kwargs)[来源]#
flax.nnx.switch(index, branches, *operands)[来源]#
flax.nnx.while_loop(cond_fun, body_fun, init_val)[来源]#

jax.lax.while_loop 的 Flax NNX 转换。

注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在 body_fun 内部更改 init_val 的变量引用结构。

示例

>>> import jax
>>> from flax import nnx
>>> def fwd_fn(input):
...   module, x, count = input
...   return module, module(x), count - 1.0

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> # `module` will be called three times
>>> _, y, _ = nnx.while_loop(
...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
参数
  • cond_fun – while 循环的继续条件函数,接受类型为 T 的单个输入并输出布尔值。

  • body_fun – 接受类型为 T 的输入并输出 T 的函数。请注意,T 的数据和模块在输入和输出之间必须具有相同的引用结构。

  • init_valcond_funbody_fun 的初始输入。必须为 T 类型。

flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[来源]#

jax.lax.fori_loop 的 Flax NNX 转换。

注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在 body_fun 内部更改 init_val 的变量引用结构。

示例

>>> import jax
>>> from flax import nnx

>>> def fwd_fn(i, input):
...   m, x = input
...   m.kernel.value = jnp.identity(10) * i
...   return m, m(x)

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
>>> np.testing.assert_array_equal(y, x * 2 * 3)
参数
  • lower – 表示循环索引下限(含)的整数。

  • upper – 表示循环索引上限(不含)的整数。

  • body_fun – 接受类型为 T 的输入并输出 T 的函数。请注意,T 的数据和模块在输入和输出之间必须具有相同的引用结构。

  • init_val – body_fun 的初始输入。必须为 T 类型。

  • unroll – 可选的整数或布尔值,用于确定展开循环的程度。如果提供了整数,则它将确定在循环的单个滚动迭代中运行多少展开的循环迭代。如果提供了布尔值,则它将确定循环是否完全展开(即 unroll=True)或完全保持展开(即 unroll=False)。此参数仅在循环边界为静态已知时适用。

返回

来自最后一次迭代的循环值,类型为 T