指标#

class flax.nnx.metrics.Metric(*args, **kwargs)#

指标的基类。任何继承自 Metric 的类都应实现 computeresetupdate 方法。

__init__()#
compute()#

计算并返回 Metric 的值。

reset()#

就地重置 Metric

update(**kwargs)#

就地更新 Metric

class flax.nnx.metrics.Average(*args, **kwargs)#

平均值指标。

用法示例

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Average()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
__init__(argname='values')#

传入一个字符串,表示 update() 将用来获取新值的关键字参数。例如,将指标构造为 avg = Average('test') 将允许您使用 avg.update(test=new_value) 进行更新。

参数

argname – 一个可选字符串,表示 update() 将用来获取新值的关键字参数。默认为 'values'

compute()#

计算并返回平均值。

reset()#

重置此 Metric

update(**kwargs)#

就地更新此 Metric。此方法将使用 kwargs[self.argname] 中的值更新指标,其中 self.argname 在构造时定义。

参数

**kwargs – 包含 self.argname 条目的关键字参数,该条目映射到我们想要用来更新此指标的值。

class flax.nnx.metrics.Accuracy(*args, **kwargs)#

准确率指标。此指标继承自 Average,因此它们共享相同的 resetcompute 方法实现。与 Average 不同,在构造期间不需要向 Accuracy 传递字符串。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> metrics = nnx.metrics.Accuracy()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
update(*, logits, labels, **_)#

就地更新此 Metric

参数
  • logits – 输出的预测激活值。在将这些值与标签进行比较之前,会对它们进行 argmax 操作(在尾部维度上)。

  • labels – 真实整数标签。

class flax.nnx.metrics.Welford(*args, **kwargs)#

使用 Welford 算法来计算数据流的均值和方差。

用法示例

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Welford()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
__init__(argname='values')#

传入一个字符串,表示 update() 将用来获取新值的关键字参数。例如,将指标构造为 wf = Welford('test') 将允许您使用 wf.update(test=new_value) 进行更新。

参数

argname – 一个可选字符串,表示 update() 将用来获取新值的关键字参数。默认为 'values'

compute()#

计算并返回 Statistics 数据类对象中的均值和方差统计信息。

reset()#

重置此 Metric

update(**kwargs)#

就地更新此 Metric。此方法将使用 kwargs[self.argname] 中的值更新指标,其中 self.argname 在构造时定义。

参数

**kwargs – 包含 self.argname 条目的关键字参数,该条目映射到我们想要用来更新此指标的值。

class flax.nnx.metrics.MultiMetric(*args, **kwargs)#

MultiMetric 类,用于存储多个指标并在单个调用中更新它们。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> metrics = nnx.MultiMetric(
...   accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )

>>> metrics
MultiMetric(
  accuracy=Accuracy(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  ),
  loss=Average(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  )
)

>>> metrics.accuracy
Accuracy(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)

>>> metrics.loss
Average(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
__init__(**metrics)#

将关键字参数传递给构造函数,例如 MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)

参数

**metrics – 将用于访问相应 Metric 的关键字参数。

compute()#

计算并返回所有底层 Metric 的值。此方法将返回一个字典,将字符串(由传递给构造函数的关键字参数 **metrics 定义)映射到相应的度量值。

reset()#

重置所有底层 Metric

update(**updates)#

就地更新此 MultiMetric 中的所有底层 Metric。所有 **updates 都将传递给所有底层 Metricupdate 方法。

参数

**updates – 将传递给底层 Metricupdate 方法的关键字参数。