数据类型#

flax.nnx.nn.dtypes.canonicalize_dtype(*args, dtype=None, inexact=True)[源代码]#

将可选的 dtype 规范化为确定的 dtype。

如果 dtype 为 None,此函数将推断 dtype。如果它不为 None,则将返回未修改的值,如果 dtype 无效,则会引发异常。从输入参数中使用 jnp.result_type 推断 dtype。

参数
  • *args – 与 JAX 数组兼容的值。None 值将被忽略。

  • dtype – 可选的 dtype 覆盖。如果指定,则参数将被转换为指定的 dtype,并禁用 dtype 推断。

  • inexact – 当为 True 时,输出 dtype 必须是子类型

  • This (of jnp.inexact. Inexact dtypes are real or complex floating points.) –

  • on (is useful when you want to apply operations that don't work directly) –

  • example. (integers like taking a mean for) –

返回值

应该将 *args 转换为的 dtype。

flax.nnx.nn.dtypes.promote_dtype(args, /, *, dtype=None, inexact=True)[源代码]#

“将输入参数提升为指定的或推断的 dtype。

所有参数都转换为相同的 dtype。有关如何确定此 dtype 的信息,请参阅 canonicalize_dtype

promote_dtype 的行为主要是在 jax.numpy.promote_types 周围的一个便捷包装器。不同之处在于它会自动将所有输入转换为推断的 dtype,允许通过强制 dtype 覆盖推断,并且可以选择检查以保证结果 dtype 是不精确的。

参数
  • *args – 与 JAX 数组兼容的值。None 值将按原样返回。

  • dtype – 可选的 dtype 覆盖。如果指定,则参数将被转换为指定的 dtype,并禁用 dtype 推断。

  • inexact – 当为 True 时,输出 dtype 必须是子类型

  • This (of jnp.inexact. Inexact dtypes are real or complex floating points.) –

  • on (is useful when you want to apply operations that don't work directly) –

  • example. (integers like taking a mean for) –

返回值

转换为相同 dtype 数组的参数。