指标#
- class flax.nnx.metrics.Metric(*args, **kwargs)#
指标的基类。任何继承自
Metric
的类都应实现compute
、reset
和update
方法。- __init__()#
- compute()#
计算并返回
Metric
的值。
- reset()#
就地重置
Metric
。
- update(**kwargs)#
就地更新
Metric
。
- class flax.nnx.metrics.Average(*args, **kwargs)#
平均值指标。
用法示例
>>> import jax.numpy as jnp >>> from flax import nnx >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics = nnx.metrics.Average() >>> metrics.compute() Array(nan, dtype=float32) >>> metrics.update(values=batch_loss) >>> metrics.compute() Array(2.5, dtype=float32) >>> metrics.update(values=batch_loss2) >>> metrics.compute() Array(2., dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32)
- __init__(argname='values')#
传入一个字符串,表示
update()
将用来获取新值的关键字参数。例如,将指标构造为avg = Average('test')
将允许您使用avg.update(test=new_value)
进行更新。- 参数
argname – 一个可选字符串,表示
update()
将用来获取新值的关键字参数。默认为'values'
。
- compute()#
计算并返回平均值。
- reset()#
重置此
Metric
。
- update(**kwargs)#
就地更新此
Metric
。此方法将使用kwargs[self.argname]
中的值更新指标,其中self.argname
在构造时定义。- 参数
**kwargs – 包含
self.argname
条目的关键字参数,该条目映射到我们想要用来更新此指标的值。
- class flax.nnx.metrics.Accuracy(*args, **kwargs)#
准确率指标。此指标继承自
Average
,因此它们共享相同的reset
和compute
方法实现。与Average
不同,在构造期间不需要向Accuracy
传递字符串。用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) >>> metrics = nnx.metrics.Accuracy() >>> metrics.compute() Array(nan, dtype=float32) >>> metrics.update(logits=logits, labels=labels) >>> metrics.compute() Array(0.6, dtype=float32) >>> metrics.update(logits=logits2, labels=labels2) >>> metrics.compute() Array(0.7, dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32)
- update(*, logits, labels, **_)#
就地更新此
Metric
。- 参数
logits – 输出的预测激活值。在将这些值与标签进行比较之前,会对它们进行 argmax 操作(在尾部维度上)。
labels – 真实整数标签。
- class flax.nnx.metrics.Welford(*args, **kwargs)#
使用 Welford 算法来计算数据流的均值和方差。
用法示例
>>> import jax.numpy as jnp >>> from flax import nnx >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics = nnx.metrics.Welford() >>> metrics.compute() Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32)) >>> metrics.update(values=batch_loss) >>> metrics.compute() Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32)) >>> metrics.update(values=batch_loss2) >>> metrics.compute() Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32)) >>> metrics.reset() >>> metrics.compute() Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
- __init__(argname='values')#
传入一个字符串,表示
update()
将用来获取新值的关键字参数。例如,将指标构造为wf = Welford('test')
将允许您使用wf.update(test=new_value)
进行更新。- 参数
argname – 一个可选字符串,表示
update()
将用来获取新值的关键字参数。默认为'values'
。
- compute()#
计算并返回
Statistics
数据类对象中的均值和方差统计信息。
- reset()#
重置此
Metric
。
- update(**kwargs)#
就地更新此
Metric
。此方法将使用kwargs[self.argname]
中的值更新指标,其中self.argname
在构造时定义。- 参数
**kwargs – 包含
self.argname
条目的关键字参数,该条目映射到我们想要用来更新此指标的值。
- class flax.nnx.metrics.MultiMetric(*args, **kwargs)#
MultiMetric 类,用于存储多个指标并在单个调用中更新它们。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> metrics = nnx.MultiMetric( ... accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average() ... ) >>> metrics MultiMetric( accuracy=Accuracy( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ), loss=Average( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ) ) >>> metrics.accuracy Accuracy( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ) >>> metrics.loss Average( argname='values', total=MetricState( value=Array(0., dtype=float32) ), count=MetricState( value=Array(0, dtype=int32) ) ) >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} >>> metrics.update(logits=logits, labels=labels, values=batch_loss) >>> metrics.compute() {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)} >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2) >>> metrics.compute() {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)} >>> metrics.reset() >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
- __init__(**metrics)#
将关键字参数传递给构造函数,例如
MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)
。- 参数
**metrics – 将用于访问相应
Metric
的关键字参数。
- compute()#
计算并返回所有底层
Metric
的值。此方法将返回一个字典,将字符串(由传递给构造函数的关键字参数**metrics
定义)映射到相应的度量值。
- reset()#
重置所有底层
Metric
。
- update(**updates)#
就地更新此
MultiMetric
中的所有底层Metric
。所有**updates
都将传递给所有底层Metric
的update
方法。- 参数
**updates – 将传递给底层
Metric
的update
方法的关键字参数。