SPMD#
用于处理 jit 和分区模型的实用程序。
该模块引入了axis_rules
、logical_to_mesh_axes
、logical_to_mesh
、with_logical_constraint
,以便根据“逻辑命名轴”而不是 jit 的默认网格轴应用 jit 分区约束。
此外,还定义了 LogicallyPartitioned
元数据包装器以及初始化器函数包装器 ``with_logical_partitioning,用于将逻辑轴元数据引入模型的变量。
- flax.linen.Partitioned(value, names, mesh=None)[source]#
用于分区元数据的包装器。
Partitioned
用于为变量扩展jax.experimental.pjit
所需的分区信息。定义分区变量的最简单方法是使用围绕变量初始化器的
with_partitioning
包装器。示例
class MLP(nn.Module): hidden_size: int @nn.compact def __call__(self, x): ki = nn.linear.default_kernel_init h = nn.Dense( self.hidden_size, kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x) h = nn.relu(h) return nn.Dense( x.shape[-1], kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h) mlp = MLP(4096) x = jnp.ones((8 * 1024, 1024)) # use eval_shape to get the Partitioned instances for the variables. # this way we can determine the PartitionSpecs for the init variables # before we call the init fn. var_spec = nn.get_partition_spec( jax.eval_shape(mlp.init, random.key(0), x)) init_fn = mesh(pjit(mlp.init, (None, PartitionSpec("data", "model")), var_spec)) variables = init_fn(random.key(0), x) apply_fn = mesh(pjit( mlp.apply, (var_spec, PartitionSpec("data", "model")), PartitionSpec("data", "model"))) apply_fn(variables, x)
当使用
nn.vmap
和nn.scan
等转换时,Partitioned
值可以获得额外的轴。在这种情况下,您可以在 vmap/scan 中使用 metadata_params 参数指定新轴的名称。class Model(nn.Module): @nn.compact def __call__(self, x): def body(mdl, c): c = MLP(4096)(c) return c, () c, _ = nn.scan( body, variable_axes={"params": 0}, split_rngs={"params": 0}, length=8, metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x) return c
- flax.linen.with_partitioning(fn, names, mesh=None)[source]#
使用 Partitioned 包装函数的返回值。
示例
>>> import flax.linen as nn >>> kernel_init = nn.with_partitioning( ... nn.initializers.lecun_normal(), (None, "data")) >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
- 参数
fn – 要包装的函数。通常,这是一个初始化器。
names – 传递给
Partitioned
的逻辑轴。mesh – 用于分区的网格。如果为 None,则使用全局网格资源(如果可用)。
- 返回值
一个包装
fn
的函数,它将返回一个Partitioned
实例。
- flax.linen.LogicallyPartitioned(value: Any, names: tuple[Optional[str], ...], mesh: jax._src.mesh.Mesh | None = None, rules: collections.abc.Sequence[tuple[str, Union[str, tuple[str, ...], NoneType]]] | None = None)[source]#
- flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)[source]#
计算数组的布局。
规则按优先级顺序排列,由对组成:
(ArrayDimensionName, MeshDimensionName)
,这意味着给定数组维度(如果存在且未使用)应该在给定网格维度(如果存在且未使用)上进行分片。数组的布局表示为元组,每个元素对应数组中的一个维度。元素为 None,或为网格维度的名称,表示数组的此维度在网格的此维度上进行分片。
例如,给定一个具有
array_dim_names = ('batch', 'length', 'heads', 'features')
的数组,布局规则为
rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
那么此函数将返回
PartitionSpec('X', None, 'Y', None)
- 参数
array_dim_names – 数组维度名称的元组或 None。
rules – 可选的逻辑到网格规则覆盖。默认情况下,使用从
axis_rules
函数定义的动态上下文设置的规则。
- 返回值
参数的 PartitionSpec。
- flax.linen.logical_to_mesh(tree, rules=None)[source]#
将逻辑 PartitionSpecs 的 PyTrees 应用于 logical_to_mesh_axes。
- flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)[source]#
将逻辑 PartitionSpecs 的 PyTrees 转换为分片。
- flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)[source]#
使用逻辑轴名称的 jit’s with_sharding_constraint 版本。
- flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)[source]#
使用 LogicallyPartitioned 包装函数的返回值。
示例
>>> import flax.linen as nn >>> kernel_init = nn.with_logical_partitioning( ... nn.initializers.lecun_normal(), (None, "data")) >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
- 参数
fn – 要包装的函数。通常,这是一个初始化器。
names – 传递给
LogicallyPartitioned
的逻辑轴。mesh – 用于分区的网格。如果为 None,则使用全局网格资源(如果可用)。
rules – 可选的逻辑到网格规则使用。如果为 None,则在可用时使用全局规则。
- 返回值
包装
fn
的函数,它将返回一个LogicallyPartitioned
实例。