转换#
- flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#
jax.grad
的提升版本,可以处理模块/图节点作为参数。每个图节点的可微分状态由 wrt 过滤器定义,默认设置为 nnx.Param。 在内部,提取图节点的
State
,根据 wrt 过滤器进行过滤,并传递到底层的jax.grad
函数。图节点的梯度类型为State
。示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) })
- 参数
fun – 要微分的函数。其参数在
argnums
指定的位置应该是数组、标量、图节点或标准 Python 容器。在argnums
指定的位置的参数数组必须是不精确的(即浮点型或复数型)。它应该返回一个标量(包括形状为()
的数组,但不包括形状为(1,)
等的数组)。argnums – 可选,整数或整数序列。指定要对其求导的(一个或多个)位置参数(默认为 0)。
has_aux – 可选,布尔值。指示
fun
是否返回一个对,其中第一个元素被认为是要求微分的数学函数的输出,第二个元素是辅助数据。默认为 False。holomorphic – 可选,布尔值。指示
fun
是否承诺是全纯的。如果为 True,则输入和输出必须是复数。默认为 False。allow_int – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有一个平凡的向量空间 dtype (float0)。默认为 False。
reduce_axes – 可选,轴名称元组。如果此处列出了一个轴,并且
fun
隐式地在该轴上广播一个值,则反向传播将执行相应梯度的psum
。否则,梯度将是每个示例在命名轴上的梯度。例如,如果'batch'
是命名的批处理轴,grad(f, reduce_axes=('batch',))
将创建一个计算总梯度的函数,而grad(f)
将创建一个计算每个示例梯度的函数。
- flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[源代码]#
jax.jit
的提升版本,可以处理模块/图节点作为参数。- 参数
fun –
要进行 JIT 编译的函数。
fun
应该是一个纯函数,因为副作用可能只执行一次。fun
的参数和返回值应该是数组、标量或其(嵌套)标准 Python 容器(元组/列表/字典)。由static_argnums
指示的位置参数可以是任何东西,只要它们是可散列的并且定义了相等运算。静态参数包含在编译缓存键中,这就是为什么必须定义散列和相等运算符的原因。JAX 保留对
fun
的弱引用,以用作编译缓存键,因此对象fun
必须是可弱引用的。大多数Callable
对象已经满足此要求。in_shardings –
与
fun
的参数结构匹配的 Pytree,所有实际参数都替换为资源分配规范。指定 pytree 前缀(例如,用一个值代替整个子树)也是有效的,在这种情况下,叶子会广播到该子树中的所有值。in_shardings
参数是可选的。JAX 将从输入jax.Array
推断分片,如果无法推断分片,则默认为复制输入。- 有效资源分配规范是
Sharding
,它将决定如何对值进行分区。有了这个,就不需要使用网格上下文管理器。
None
,将使 JAX 可以自由选择它想要的任何分片。对于 in_shardings,JAX 将其标记为复制,但此行为将来可能会更改。对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。
每个维度的大小必须是分配给它的资源总数的倍数。这类似于 pjit 的 in_shardings。
out_shardings –
类似于
in_shardings
,但指定函数输出的资源分配。这类似于 pjit 的 out_shardings。out_shardings
参数是可选的。如果未指定,jax.jit()
将使用 GSPMD 的分片传播来确定输出的分片应该是什么。static_argnums –
一个可选的整数或整数集合,用于指定将哪些位置参数视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在追踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。
静态参数应该是可哈希的,这意味着
__hash__
和__eq__
都已实现,并且是不可变的。使用这些常量的不同值调用 jitted 函数将触发重新编译。不是数组或其容器的参数必须标记为静态。如果既没有提供
static_argnums
也没有提供static_argnames
,则不将任何参数视为静态。如果未提供static_argnums
但提供了static_argnames
,反之亦然,JAX 会使用inspect.signature(fun)
来查找与static_argnames
相对应的任何位置参数(反之亦然)。如果同时提供了static_argnums
和static_argnames
,则不使用inspect.signature
,并且仅将static_argnums
或static_argnames
中列出的实际参数视为静态。static_argnames – 一个可选的字符串或字符串集合,用于指定将哪些命名参数视为静态(编译时常量)。有关详细信息,请参阅有关
static_argnums
的注释。如果未提供,但设置了static_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。donate_argnums –
指定哪些位置参数缓冲区“捐赠”给计算。如果您在计算完成后不再需要参数缓冲区,则可以安全地捐赠它们。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重复使用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将引发错误。默认情况下,不捐赠任何参数缓冲区。
如果既没有提供
donate_argnums
也没有提供donate_argnames
,则不捐赠任何参数。如果未提供donate_argnums
但提供了donate_argnames
,反之亦然,JAX 会使用inspect.signature(fun)
来查找与donate_argnames
相对应的任何位置参数(反之亦然)。如果同时提供了donate_argnums
和donate_argnames
,则不使用inspect.signature
,并且仅将donate_argnums
或donate_argnames
中列出的实际参数捐赠出去。有关缓冲区捐赠的更多详细信息,请参阅FAQ。
donate_argnames – 一个可选的字符串或字符串集合,用于指定将哪些命名参数捐赠给计算。有关详细信息,请参阅有关
donate_argnums
的注释。如果未提供,但设置了donate_argnums
,则默认值基于调用inspect.signature(fun)
来查找相应的命名参数。keep_unused – 如果 False(默认值),则 JAX 确定 fun 未使用的参数可能会从生成的已编译 XLA 可执行文件中删除。这些参数不会传输到设备或提供给底层可执行文件。如果 True,则不会修剪未使用的参数。
device – 这是一个实验性功能,API 可能会发生变化。可选,jitted 函数将在其上运行的设备。(可以通过
jax.devices()
检索可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用jax.devices()[0]
。backend – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。inline – 指定此函数是否应内联到封闭的 jaxprs 中(而不是表示为具有其自身 subjaxpr 的 xla_call 原语的应用程序)。默认为 False。
- 返回
为即时编译设置的
fun
的包装版本。
- flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[source]#
- flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[source]#
- flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
- flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[source]#
jax.vmap 的引用感知版本。
- 参数
f – 要在附加轴上映射的函数。
in_axes – 一个整数、None 或值序列,用于指定要映射的输入数组轴(请参阅jax.vmap)。除了整数和 None 之外,
StateAxes
可以用于控制图节点(如模块)的矢量化方式,方法是指定应用于图节点子状态的轴,给定一个过滤器。out_axes – 一个整数、None 或 pytree,指示映射轴应出现在输出中的位置(请参阅jax.vmap)。
axis_name – 可选,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合。
axis_size – 可选,一个整数,指示要映射的轴的大小。如果未提供,则从参数推断映射的轴大小。
- 返回
带有参数的
f
的批处理/矢量化版本,这些参数与f
的参数相对应,但在in_axes
指示的位置有额外的数组轴,并且返回值与f
的返回值相对应,但在out_axes
指示的位置有额外的数组轴。
示例
>>> from flax import nnx >>> from jax import random, numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((5, 2)) ... >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) ... def forward(model, x): ... return model(x) ... >>> y = forward(model, x) >>> y.shape (5, 3)
>>> class LinearEnsemble(nnx.Module): ... def __init__(self, num, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) ... >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) >>> x = jnp.ones((2,)) ... >>> @nnx.vmap(in_axes=(0, None), out_axes=0) ... def forward(model, x): ... return jnp.dot(x, model.w.value) ... >>> y = forward(model, x) >>> y.shape (5, 3)
要控制图节点子状态的矢量化方式,可以将
StateAxes
传递给in_axes
和out_axes
,指定应用于给定过滤器的每个子状态的轴。以下示例显示了如何在保持不同批次统计数据和 dropout 随机状态的同时在集成成员之间共享参数>>> class Foo(nnx.Module): ... def __init__(self): ... self.a = nnx.Param(jnp.arange(4)) ... self.b = nnx.BatchStat(jnp.arange(4)) ... >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) ... def mul(foo): ... return foo.a * foo.b ... >>> foo = Foo() >>> y = mul(foo) >>> y Array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 2, 4, 6], [0, 3, 6, 9]], dtype=int32)
- flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[来源]#
jax.custom_vjp 的引用感知版本。
nnx.custom_vjp
接受模块和其他 Flax NNX 对象作为参数。与 JAX 版本的主要区别在于,由于模块遵循引用语义,它们会将输入的 State 更新作为辅助输出传播。这意味着bwd
函数中的传入梯度将具有(input_updates_g, out_g)
的形式,其中input_updates_g
是输入相对于输入的梯度更新状态。输入上的所有模块项都将在input_updates_g
中有一个关联的State
项,而所有非模块项将显示为 None。tanget 的形状应与输入的形状相同,其中State
项替换相应的模块项。示例
>>> import jax >>> import jax.numpy as jnp >>> from flax import nnx ... >>> class Foo(nnx.Module): ... def __init__(self, x, y): ... self.x = nnx.Param(x) ... self.y = nnx.Param(y) ... >>> @nnx.custom_vjp ... def f(m: Foo): ... return jnp.sin(m.x) * m.y ... >>> def f_fwd(m: Foo): ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, sin_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g['x'].value = cos_x * out_g * m.y ... m_g['y'].value = sin_x * out_g ... return (m_g,) ... >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grads = nnx.grad(f)(m) ... >>> jax.tree.map(jnp.shape, grads) State({ 'x': VariableState( type=Param, value=() ), 'y': VariableState( type=Param, value=() ) })
请注意,表示
input_updates_g
上模块项的 State 对象与输出 tanget 中预期的 State 对象具有相同的形状。这意味着您通常可以直接从input_updates_g
复制它们,并使用其对应的梯度值更新它们。您可以通过将
DiffState
传递给nondiff_argnums
来选择模块和其他图节点的可微分子状态(具有切线)。例如,如果您只想微分Foo
类的x
属性,您可以执行以下操作>>> x_attribute = nnx.PathContains('x') >>> diff_state = nnx.DiffState(0, x_attribute) ... >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,)) ... def f(m: Foo): ... return jnp.sin(m.x) * m.y # type: ignore >>> def f_fwd(m: Foo): ... y = f(m) ... res = (jnp.cos(m.x), m) # type: ignore ... return y, res ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g.x.value = cos_x * out_g * m.y ... del m_g['y'] # y is not differentiable ... return (m_g,) >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m) ... >>> jax.tree.map(jnp.shape, grad) State({ 'x': VariableState( type=Param, value=() ) })
请注意,
grad
无法计算没有由custom_vjp
定义的切线的状态的梯度,在上面的示例中,我们重用相同的x_attribute
过滤器来保持custom_vjp
和grad
的同步。- 参数
fun – 可调用的基本函数。
nondiff_argnums – 指定不微分的参数索引的整数或 DiffState 对象元组。默认情况下,所有参数都被微分。整数不能用于将模块等图节点标记为不可微分,在这种情况下,请使用 DiffState 对象。DiffState 对象定义可微分的子状态集,与此参数名称所暗示的相反,这样做是为了与
grad
兼容。
- flax.nnx.while_loop(cond_fun, body_fun, init_val)[来源]#
jax.lax.while_loop 的 Flax NNX 转换。
注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在
body_fun
内部更改init_val
的变量引用结构。示例
>>> import jax >>> from flax import nnx >>> def fwd_fn(input): ... module, x, count = input ... return module, module(x), count - 1.0 >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> # `module` will be called three times >>> _, y, _ = nnx.while_loop( ... lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
- 参数
cond_fun – while 循环的继续条件函数,接受类型为
T
的单个输入并输出布尔值。body_fun – 接受类型为
T
的输入并输出T
的函数。请注意,T
的数据和模块在输入和输出之间必须具有相同的引用结构。init_val –
cond_fun
和body_fun
的初始输入。必须为T
类型。
- flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[来源]#
jax.lax.fori_loop 的 Flax NNX 转换。
注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在 body_fun 内部更改 init_val 的变量引用结构。
示例
>>> import jax >>> from flax import nnx >>> def fwd_fn(i, input): ... m, x = input ... m.kernel.value = jnp.identity(10) * i ... return m, m(x) >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) >>> np.testing.assert_array_equal(y, x * 2 * 3)
- 参数
lower – 表示循环索引下限(含)的整数。
upper – 表示循环索引上限(不含)的整数。
body_fun – 接受类型为
T
的输入并输出T
的函数。请注意,T
的数据和模块在输入和输出之间必须具有相同的引用结构。init_val – body_fun 的初始输入。必须为
T
类型。unroll – 可选的整数或布尔值,用于确定展开循环的程度。如果提供了整数,则它将确定在循环的单个滚动迭代中运行多少展开的循环迭代。如果提供了布尔值,则它将确定循环是否完全展开(即
unroll=True
)或完全保持展开(即unroll=False
)。此参数仅在循环边界为静态已知时适用。
- 返回
来自最后一次迭代的循环值,类型为
T
。