检查#
- flax.linen.tabulate(module, rngs, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#
返回一个创建模块摘要作为表格的函数。
此函数接受大多数相同的参数,并在内部调用Module.init,但它返回一个形式为(*args, **kwargs) -> str的函数,其中*args和**kwargs在正向传递期间传递给method(例如__call__)。
tabulate 在内部使用jax.eval_shape 来运行正向计算,而不会消耗任何 FLOPs 或分配内存。
可以将其他参数传递到console_kwargs 参数中,例如{‘width’: 120}。有关 console_kwargs 参数的完整列表,请参见:https://rich.pythonlang.cn/en/stable/reference/console.html#rich.console.Console
示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> tabulate_fn = nn.tabulate( ... Foo(), jax.random.key(0), compute_flops=True, compute_vjp_flops=True) >>> # print(tabulate_fn(x))
这将给出以下输出
Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B)
注意:表格中行的顺序不代表执行顺序,而是与variables 中的键的顺序一致,这些键按字母顺序排序。
注意:如果模块不可微分,则vjp_flops 返回0。
- 参数
module – 要制表的模块。
rngs – 作为变量集合传递给Module.init 的 rng。
depth – 控制摘要可以深入多少个子模块。默认情况下为None,这意味着没有限制。如果某个子模块由于深度限制而未显示,则其参数计数和字节将添加到其第一个显示的祖先的行中,以便所有行的总和始终加起来等于模块的总参数数量。
show_repeated – 如果为True,则表格中将显示对同一模块的重复调用,否则只显示第一个调用。默认值为False。
mutable – 可以是 bool、str 或 list。指定哪些集合应该被视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称列表。默认情况下,除了‘intermediates’ 之外的所有集合都是可变的。console_kwargs – 可选字典,其中包含传递给rich.console.Console 的其他关键字参数,用于呈现表格。默认参数为{‘force_terminal’: True, ‘force_jupyter’: False}。
table_kwargs – 可选字典,其中包含传递给rich.table.Table 构造函数的其他关键字参数。
column_kwargs – 可选字典,其中包含传递给rich.table.Table.add_column 的其他关键字参数,用于向表格添加列。
compute_flops – 是否在表格中包含flops 列,列出每个模块正向传递的估计 FLOPs 成本。确实会产生实际的设备上计算/编译/内存分配,但对于大型模块仍然会引入开销(例如,Stable Diffusion 的 UNet 额外 20 秒,而其他制表将在 5 秒内完成)。
compute_vjp_flops – 是否在表格中包含vjp_flops 列,列出每个模块反向传递的估计 FLOPs 成本。引入了大约 2-3 倍于compute_flops 的计算开销。
**kwargs – 传递给Module.init 的其他参数。
- 返回
一个接受正向传递(method)的相同*args 和**kwargs,并返回包含模块表格表示的字符串的函数。