state#
- class flax.nnx.State(mapping, /, *, _copy=True)[源代码]#
一个类似 pytree 的结构,包含从可哈希和可比较的键到叶子的
Mapping
。叶子可以是任何类型,但VariableState
和Variable
是最常见的。- filter(first, /, *filters)[源代码]#
将
State
过滤成一个或多个State
。用户必须至少传递一个Filter
(即Variable
)。此方法类似于split()
,只是过滤器可以是不完全的。用法示例
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param = state.filter(nnx.Param) >>> batch_stats = state.filter(nnx.BatchStat) >>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
- 参数
first – 第一个过滤器
*filters – 可选的附加过滤器,用于将状态分组为互斥的子状态。
- 返回值
一个或多个
State
,数量等于传递的过滤器数量。
- static merge(state, /, *states)[源代码]#
与
split()
相反。merge
接受一个或多个State
,并创建一个新的State
。用法示例
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> params.linear.bias.value += 1 >>> state = nnx.State.merge(params, batch_stats) >>> nnx.update(model, state) >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
- 参数
state – 一个
State
对象。*states – 额外的
State
对象。
- 返回值
合并后的
State
。
- split(first, /, *filters)[源代码]#
将
State
拆分为一个或多个State
。用户必须至少传递一个Filter
(即Variable
),并且过滤器必须是详尽的(即它们必须覆盖State
中的所有Variable
类型)。用法示例
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
- 参数
first – 第一个过滤器
*filters – 可选的附加过滤器,用于将状态分组为互斥的子状态。
- 返回值
一个或多个
State
,数量等于传递的过滤器数量。