变量字典

内容

变量字典#

变量字典是一个普通的 Python 字典,它是一个或多个“变量集合”的容器,每个集合都是嵌套字典,其叶子是 jax.numpy 数组。

不同的变量集合共享相同的嵌套树结构。

例如,考虑以下变量字典

{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ...}
  }
}

在这种情况下,"BatchNorm1" 键同时存在于 "params"`"batch_stats"" 集合中。这反映了名为 ""BatchNorm1"" 的子模块既有可训练参数("params" 集合),也有其他不可训练变量("batch_stats" 集合)。

待办事项:创建“变量字典”设计说明,并从这里链接到它。

class flax.linen.Variable(scope, collection, name, unbox)[source]#

Variable 对象允许对 VariableDict 中的变量进行可变访问。

变量由集合(例如,“batch_stats”)和名称(例如,“moving_mean”)标识。value 属性提供对变量内容的访问,可以对其进行赋值以进行修改。