检查

内容

检查#

flax.linen.tabulate(module, rngs, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#

返回一个创建模块摘要作为表格的函数。

此函数接受大多数相同的参数,并在内部调用Module.init,但它返回一个形式为(*args, **kwargs) -> str的函数,其中*args**kwargs在正向传递期间传递给method(例如__call__)。

tabulate 在内部使用jax.eval_shape 来运行正向计算,而不会消耗任何 FLOPs 或分配内存。

可以将其他参数传递到console_kwargs 参数中,例如{‘width’: 120}。有关 console_kwargs 参数的完整列表,请参见:https://rich.pythonlang.cn/en/stable/reference/console.html#rich.console.Console

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> tabulate_fn = nn.tabulate(
...     Foo(), jax.random.key(0), compute_flops=True, compute_vjp_flops=True)

>>> # print(tabulate_fn(x))

这将给出以下输出

                                       Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                               Total Parameters: 50 (200 B)

注意:表格中行的顺序不代表执行顺序,而是与variables 中的键的顺序一致,这些键按字母顺序排序。

注意:如果模块不可微分,则vjp_flops 返回0

参数
  • module – 要制表的模块。

  • rngs – 作为变量集合传递给Module.init 的 rng。

  • depth – 控制摘要可以深入多少个子模块。默认情况下为None,这意味着没有限制。如果某个子模块由于深度限制而未显示,则其参数计数和字节将添加到其第一个显示的祖先的行中,以便所有行的总和始终加起来等于模块的总参数数量。

  • show_repeated – 如果为True,则表格中将显示对同一模块的重复调用,否则只显示第一个调用。默认值为False

  • mutable – 可以是 bool、str 或 list。指定哪些集合应该被视为可变的:bool:所有/没有集合是可变的。 str:单个可变集合的名称。 list:可变集合名称列表。默认情况下,除了‘intermediates’ 之外的所有集合都是可变的。

  • console_kwargs – 可选字典,其中包含传递给rich.console.Console 的其他关键字参数,用于呈现表格。默认参数为{‘force_terminal’: True, ‘force_jupyter’: False}

  • table_kwargs – 可选字典,其中包含传递给rich.table.Table 构造函数的其他关键字参数。

  • column_kwargs – 可选字典,其中包含传递给rich.table.Table.add_column 的其他关键字参数,用于向表格添加列。

  • compute_flops – 是否在表格中包含flops 列,列出每个模块正向传递的估计 FLOPs 成本。确实会产生实际的设备上计算/编译/内存分配,但对于大型模块仍然会引入开销(例如,Stable Diffusion 的 UNet 额外 20 秒,而其他制表将在 5 秒内完成)。

  • compute_vjp_flops – 是否在表格中包含vjp_flops 列,列出每个模块反向传递的估计 FLOPs 成本。引入了大约 2-3 倍于compute_flops 的计算开销。

  • **kwargs – 传递给Module.init 的其他参数。

返回

一个接受正向传递(method)的相同*args**kwargs,并返回包含模块表格表示的字符串的函数。