Optimizer#
- class flax.nnx.optimizer.Optimizer(*args, **kwargs)#
对于使用单个 Optax 优化器的常见情况的简单训练状态。
用法示例
>>> import jax, jax.numpy as jnp >>> from flax import nnx >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... return self.linear2(self.linear1(x)) ... >>> x = jax.random.normal(jax.random.key(0), (1, 2)) >>> y = jnp.ones((1, 4)) ... >>> model = Model(nnx.Rngs(0)) >>> tx = optax.adam(1e-3) >>> state = nnx.Optimizer(model, tx) ... >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() >>> loss_fn(model) Array(1.7055722, dtype=float32) >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads) >>> loss_fn(model) Array(1.6925814, dtype=float32)
请注意,您可以通过子类化它来轻松扩展此类,以存储其他数据(例如添加指标)。
用法示例
>>> class TrainState(nnx.Optimizer): ... def __init__(self, model, tx, metrics): ... self.metrics = metrics ... super().__init__(model, tx) ... def update(self, *, grads, **updates): ... self.metrics.update(**updates) ... super().update(grads) ... >>> metrics = nnx.metrics.Average() >>> state = TrainState(model, tx, metrics) ... >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() Array(1.6925814, dtype=float32) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() Array(1.68612, dtype=float32)
对于更复杂的用例(例如多个优化器),最好 fork 该类并对其进行修改。
- step#
一个
OptState
Variable
,用于跟踪步数。
- model#
包装的
Module
。
- tx#
一个 Optax 梯度变换。
- opt_state#
Optax 优化器状态。
- __init__(model, tx, wrt=<class 'flax.nnx.variablelib.Param'>)#
实例化该类并包装
Module
和 Optax 梯度变换。实例化优化器状态以跟踪wrt
中指定的Variable
类型。将步数设置为 0。- 参数
model – 一个 NNX 模块。
tx – 一个 Optax 梯度变换。
wrt – 可选参数,用于过滤要在优化器状态中跟踪的
Variable
。这些应该是您计划更新的Variable
;即,此参数值应与传递给nnx.grad
调用的wrt
参数匹配,该调用将生成将传递到update()
方法的grads
参数中的梯度。
- update(grads, **kwargs)#
更新返回值中的
step
、params
、opt_state
和**kwargs
。grads
必须来自nnx.grad(..., wrt=self.wrt)
,其中梯度相对于实例化此Optimizer
期间在self.wrt
中定义的相同Variable
类型。 例如>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> import optax >>> class CustomVariable(nnx.Variable): ... pass >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.custom_variable = CustomVariable(jnp.ones((1, 3))) ... def __call__(self, x): ... return self.linear(x) + self.custom_variable >>> model = Model(rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(model)) State({ 'custom_variable': VariableState( type=CustomVariable, value=(1, 3) ), 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> # update: >>> # - only Linear layer parameters >>> # - only CustomVariable parameters >>> # - both Linear layer and CustomVariable parameters >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() >>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)): ... # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad` ... state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) ... grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))( ... state.model, jnp.ones((1, 2)), jnp.ones((1, 3)) ... ) ... state.update(grads=grads)
请注意,此函数内部调用
.tx.update()
,然后调用optax.apply_updates()
以更新params
和opt_state
。- 参数
grads – 来自
nnx.grad
的梯度。**kwargs – 传递给 tx.update 的其他关键字参数,以支持
GradientTransformationExtraArgs –
optax.scale_by_backtracking_linesearch. (例如) –