state#

class flax.nnx.State(mapping, /, *, _copy=True)[源代码]#

一个类似 pytree 的结构,包含从可哈希和可比较的键到叶子的 Mapping。叶子可以是任何类型,但 VariableStateVariable 是最常见的。

filter(first, /, *filters)[源代码]#

State 过滤成一个或多个 State。用户必须至少传递一个 Filter(即 Variable)。此方法类似于 split(),只是过滤器可以是不完全的。

用法示例

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param = state.filter(nnx.Param)
>>> batch_stats = state.filter(nnx.BatchStat)
>>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
参数
  • first – 第一个过滤器

  • *filters – 可选的附加过滤器,用于将状态分组为互斥的子状态。

返回值

一个或多个 State,数量等于传递的过滤器数量。

static merge(state, /, *states)[源代码]#

split() 相反。

merge 接受一个或多个 State,并创建一个新的 State

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> params.linear.bias.value += 1

>>> state = nnx.State.merge(params, batch_stats)
>>> nnx.update(model, state)
>>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
参数
  • state – 一个 State 对象。

  • *states – 额外的 State 对象。

返回值

合并后的 State

split(first, /, *filters)[源代码]#

State 拆分为一个或多个 State。用户必须至少传递一个 Filter(即 Variable),并且过滤器必须是详尽的(即它们必须覆盖 State 中的所有 Variable 类型)。

用法示例

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
参数
  • first – 第一个过滤器

  • *filters – 可选的附加过滤器,用于将状态分组为互斥的子状态。

返回值

一个或多个 State,数量等于传递的过滤器数量。