#

flax.nnx.split(node, *filters)[源代码]#

将图节点拆分为一个 GraphDef 和一个或多个 State。`State` 是一个从字符串或整数到 Variables、数组或嵌套 State 的 `Mapping`。GraphDef 包含重建 Module 图所需的所有静态信息,它类似于 JAX 的 PyTreeDefsplit()merge() 结合使用,可以在图的有状态和无状态表示之间无缝切换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': VariableState(
      type=Param,
      value=(2,)
    ),
    'scale': VariableState(
      type=Param,
      value=(2,)
    )
  },
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': VariableState(
      type=BatchStat,
      value=(2,)
    ),
    'var': VariableState(
      type=BatchStat,
      value=(2,)
    )
  }
})

split()merge() 主要用于直接与 JAX 转换交互,有关更多信息,请参阅 函数式 API

参数
  • node – 要拆分的图节点。

  • *filters – 一些可选的过滤器,用于将状态分组到互斥的子状态中。

返回

GraphDef 和一个或多个 States,其数量等于传递的过滤器数量。如果未传递过滤器,则返回单个 State

flax.nnx.merge(graphdef, state, /, *states)[源代码]#

flax.nnx.split() 的逆操作。

nnx.merge 接受一个 flax.nnx.GraphDef 和一个或多个 flax.nnx.State,并创建一个与原始节点结构相同的新节点。

回顾:flax.nnx.split() 用于通过以下方式表示 flax.nnx.Module:1)一个静态的 nnx.GraphDef,它捕获其 Pythonic 静态信息;2)一个或多个 flax.nnx.Variable nnx.State (s),它以 JAX pytrees 的形式捕获其 jax.Array

nnx.mergennx.split 结合使用,可以在图的有状态和无状态表示之间无缝切换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> new_node = nnx.merge(graphdef, params, batch_stats)
>>> assert isinstance(new_node, Foo)
>>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
>>> assert isinstance(new_node.linear, nnx.Linear)

nnx.splitnnx.merge 主要用于直接与 JAX 转换交互(有关更多信息,请参阅 函数式 API)。

参数
返回

合并后的 flax.nnx.Module

flax.nnx.update(node, state, /, *states)[源代码]#

使用新的状态(们)就地更新给定的图节点。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

>>> def loss_fn(model, x, y):
...   return jnp.mean((y - model(x))**2)
>>> prev_loss = loss_fn(model, x, y)

>>> grads = nnx.grad(loss_fn)(model, x, y)
>>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
>>> nnx.update(model, new_state)
>>> assert loss_fn(model, x, y) < prev_loss
参数
  • node – 要更新的图节点。

  • state – 一个 State 对象。

  • *states – 其他 State 对象。

flax.nnx.pop(node, *filters)[源代码]#

从图节点中弹出一种或多种 Variable 类型。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')

>>> intermediates = nnx.pop(model, nnx.Intermediate)
>>> assert intermediates['i'].value[0].shape == (1, 3)
>>> assert not hasattr(model, 'i')
参数
  • node – 一个图节点对象。

  • *filters – 要按其筛选的一个或多个 Variable 对象。

返回

弹出的 State 包含被筛选的 Variable 对象。

flax.nnx.state(node, *filters)[源代码]#

类似于 split(),但仅返回由过滤器指示的 State

用法示例

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
参数
  • node – 一个图节点对象。

  • *filters – 要按其筛选的一个或多个 Variable 对象。

返回

一个或多个 State 映射。

flax.nnx.variables(node, *filters)[源代码]#

类似于 state(),但返回当前的 Variable 对象,而不是新的 VariableState 实例。

示例

>>> from flax import nnx
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> params = nnx.variables(model, nnx.Param)
...
>>> assert params['kernel'] is model.kernel
>>> assert params['bias'] is model.bias
参数
  • node – 一个图节点对象。

  • *filters – 要按其筛选的一个或多个 Variable 对象。

返回

一个或多个包含 Variable 对象的 State 映射。

flax.nnx.graph()#
flax.nnx.graphdef(node, /)[源代码]#

获取给定图节点的 GraphDef

用法示例

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, _ = nnx.split(model)
>>> assert graphdef == nnx.graphdef(model)
参数

node – 一个图节点对象。

返回

Module 对象的 GraphDef

flax.nnx.iter_graph(node, /)[源代码]#

迭代给定图节点的所有嵌套节点和叶子,包括当前节点。

iter_graph 创建一个生成器,它产生路径和值对,其中路径是一个字符串或整数的元组,表示从根到该值的路径。重复的节点只访问一次。叶子包括静态值。

示例:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.din, self.dout = din, dout
...     self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...
>>> module = Linear(3, 4, rngs=nnx.Rngs(0))
>>> graph = [module, module]
...
>>> for path, value in nnx.iter_graph(graph):
...   print(path, type(value).__name__)
...
(0, 'b') Param
(0, 'din') int
(0, 'dout') int
(0, 'w') Param
(0,) Linear
() list
flax.nnx.clone(node)[源代码]#

创建给定图节点的深层副本。

用法示例

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> cloned_model = nnx.clone(model)
>>> model.bias.value += 1
>>> assert (model.bias.value != cloned_model.bias.value).all()
参数

node – 一个图节点对象。

返回

Module 对象的深层副本。

flax.nnx.call(graphdef_state, /)[源代码]#

调用由 (GraphDef, State) 对定义的底层图节点的方法。

call 接受一个 (GraphDef, State) 对,并创建一个代理对象,该对象可用于调用底层图节点的方法。当调用一个方法时,将返回输出,以及一个新的 (GraphDef, State) 对,该对表示图节点的更新状态。call 等价于 merge() > method > split`(),但在纯 JAX 函数中使用更方便。

示例

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
>>> linear_state = nnx.split(linear)
...
>>> @jax.jit
... def forward(x, linear_state):
...   y, linear_state = nnx.call(linear_state)(x)
...   return y, linear_state
...
>>> x = jnp.ones((1, 3))
>>> y, linear_state = forward(x, linear_state)
>>> y, linear_state = forward(x, linear_state)
...
>>> linear = nnx.merge(*linear_state)
>>> linear.count.value
Array(2, dtype=uint32)

call 返回的代理对象支持索引和属性访问,以访问嵌套方法。在下面的示例中,increment 方法索引用于调用 nodes 字典的 b 键处的 StatefulLinear 模块的 increment 方法。

>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> rngs = nnx.Rngs(0)
>>> nodes = dict(
...   a=StatefulLinear(3, 2, rngs),
...   b=StatefulLinear(2, 1, rngs),
... )
...
>>> node_state = nnx.split(nodes)
>>> # use attribute access
>>> _, node_state = nnx.call(node_state)['b'].increment()
...
>>> nodes = nnx.merge(*node_state)
>>> nodes['a'].count.value
Array(0, dtype=uint32)
>>> nodes['b'].count.value
Array(1, dtype=uint32)
class flax.nnx.GraphDef[源代码]#

一个类,表示 Flax Module 的所有静态、无状态和 Pythonic 部分。GraphDef 可以通过在 Module 上调用 split()graphdef() 来生成。

class flax.nnx.UpdateContext(tag, ref_index, index_ref)[源代码]#

用于处理复杂状态更新的上下文管理器。

merge(graphdef, state, *states)[源代码]#
split(node, *filters)[源代码]#

将图节点拆分为一个 GraphDef 和一个或多个 State。`State` 是一个从字符串或整数到 Variables、数组或嵌套 State 的 `Mapping`。GraphDef 包含重建 Module 图所需的所有静态信息,它类似于 JAX 的 PyTreeDefsplit()merge() 结合使用,可以在图的有状态和无状态表示之间无缝切换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': VariableState(
      type=Param,
      value=(2,)
    ),
    'scale': VariableState(
      type=Param,
      value=(2,)
    )
  },
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': VariableState(
      type=BatchStat,
      value=(2,)
    ),
    'var': VariableState(
      type=BatchStat,
      value=(2,)
    )
  }
})
参数
  • node – 要拆分的图节点。

  • *filters – 一些可选的过滤器,用于将状态分组到互斥的子状态中。

返回

GraphDef 和一个或多个 State,其数量等于传递的过滤器数量。如果未传递过滤器,则返回单个 State

flax.nnx.update_context(tag)[源代码]#

创建一个 UpdateContext 上下文管理器,该管理器可用于处理比 nnx.update 可以处理的更复杂的状态更新,包括对静态属性和图结构的更新。

UpdateContext 公开了一个 splitmerge API,其签名与 nnx.split / nnx.merge 相同,但执行一些簿记操作,以便拥有必要的信息,以便根据转换内部所做的更改完美地更新输入对象。UpdateContext 必须总共调用 split 和 merge 4 次,第一次和最后一次调用发生在转换外部,第二次和第三次调用发生在转换内部,如下图所示

                      idxmap
(2) merge ─────────────────────────────► split (3)
      ▲                                    │
      │               inside               │
      │. . . . . . . . . . . . . . . . . . │ index_mapping
      │               outside              │
      │                                    ▼
(1) split──────────────────────────────► merge (4)
                      refmap

第一次调用 split (1) 创建一个 refmap,它跟踪外部引用,第一次调用 merge (2) 创建一个 idxmap,它跟踪内部引用。第二次调用 split (3) 将 refmap 和 idxmap 组合起来,生成 index_mapping,它指示外部引用如何映射到内部引用。最后,最后一次调用 merge (4) 使用 index_mapping 和 refmap 来重建转换的输出,同时重用/更新内部引用。为了避免内存泄漏,在 (3) 之后清除 idxmap,在 (4) 之后清除 refmap,并且在上下文管理器退出后都会清除。

下面是一个简单的示例,演示了 update_context 的用法

>>> from flax import nnx
...
>>> m1 = nnx.Dict({})
>>> with nnx.update_context('example') as ctx:
...   graphdef, state = ctx.split(m1)
...   @jax.jit
...   def f(graphdef, state):
...     m2 = ctx.merge(graphdef, state)
...     m2.a = 1
...     m2.ref = m2  # create a reference cycle
...     return ctx.split(m2)
...   graphdef_out, state_out = f(graphdef, state)
...   m3 = ctx.merge(graphdef_out, state_out)
...
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

请注意,update_context 接受一个 tag 参数,该参数主要用作一种安全机制,以降低在使用 current_update_context() 访问当前活动上下文时意外使用错误 UpdateContext 的风险。current_update_context 可以用作访问当前活动上下文的一种方式,而无需将其作为捕获传递

>>> from flax import nnx
...
>>> m1 = nnx.Dict({})
>>> @jax.jit
... def f(graphdef, state):
...   ctx = nnx.current_update_context('example')
...   m2 = ctx.merge(graphdef, state)
...   m2.a = 1     # insert static attribute
...   m2.ref = m2  # create a reference cycle
...   return ctx.split(m2)
...
>>> @nnx.update_context('example')
... def g(m1):
...   ctx = nnx.current_update_context('example')
...   graphdef, state = ctx.split(m1)
...   graphdef_out, state_out = f(graphdef, state)
...   return ctx.merge(graphdef_out, state_out)
...
>>> m3 = g(m1)
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

如上面的代码所示,update_context 也可以用作装饰器,它在函数执行期间创建/激活 UpdateContext 上下文。可以使用 current_update_context() 访问上下文。

参数

tag – 用于标识上下文的字符串标签。

flax.nnx.current_update_context(tag)[源代码]#

返回给定标签的当前活动的 UpdateContext