变量字典#
变量字典是一个普通的 Python 字典,它是一个或多个“变量集合”的容器,每个集合都是嵌套字典,其叶子是 jax.numpy
数组。
不同的变量集合共享相同的嵌套树结构。
例如,考虑以下变量字典
{
"params": {
"Conv1": { "weight": ..., "bias": ... },
"BatchNorm1": { "scale": ..., "mean": ... },
"Conv2": {...}
},
"batch_stats": {
"BatchNorm1": { "moving_mean": ..., "moving_average": ...}
}
}
在这种情况下,"BatchNorm1"
键同时存在于 "params"
和 `"batch_stats""
集合中。这反映了名为 ""BatchNorm1""
的子模块既有可训练参数("params"
集合),也有其他不可训练变量("batch_stats"
集合)。
待办事项:创建“变量字典”设计说明,并从这里链接到它。