SPMD#

用于处理 jit 和分区模型的实用程序。

该模块引入了axis_ruleslogical_to_mesh_axeslogical_to_meshwith_logical_constraint,以便根据“逻辑命名轴”而不是 jit 的默认网格轴应用 jit 分区约束。

此外,还定义了 LogicallyPartitioned 元数据包装器以及初始化器函数包装器 ``with_logical_partitioning,用于将逻辑轴元数据引入模型的变量。

flax.linen.Partitioned(value, names, mesh=None)[source]#

用于分区元数据的包装器。

Partitioned 用于为变量扩展 jax.experimental.pjit 所需的分区信息。

定义分区变量的最简单方法是使用围绕变量初始化器的 with_partitioning 包装器。

示例

class MLP(nn.Module):
  hidden_size: int
  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)
mlp = MLP(4096)
x = jnp.ones((8 * 1024, 1024))
# use eval_shape to get the Partitioned instances for the variables.
# this way we can determine the PartitionSpecs for the init variables
# before we call the init fn.
var_spec = nn.get_partition_spec(
    jax.eval_shape(mlp.init, random.key(0), x))
init_fn = mesh(pjit(mlp.init,
                    (None, PartitionSpec("data", "model")), var_spec))
variables = init_fn(random.key(0), x)
apply_fn = mesh(pjit(
    mlp.apply,
    (var_spec, PartitionSpec("data", "model")),
     PartitionSpec("data", "model")))
apply_fn(variables, x)

当使用 nn.vmapnn.scan 等转换时,Partitioned 值可以获得额外的轴。在这种情况下,您可以在 vmap/scan 中使用 metadata_params 参数指定新轴的名称。

class Model(nn.Module):
@nn.compact
def __call__(self, x):
  def body(mdl, c):
    c = MLP(4096)(c)
    return c, ()
  c, _ = nn.scan(
      body, variable_axes={"params": 0}, split_rngs={"params": 0}, length=8,
      metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
  return c
flax.linen.with_partitioning(fn, names, mesh=None)[source]#

使用 Partitioned 包装函数的返回值。

示例

>>> import flax.linen as nn
>>> kernel_init = nn.with_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
参数
  • fn – 要包装的函数。通常,这是一个初始化器。

  • names – 传递给 Partitioned 的逻辑轴。

  • mesh – 用于分区的网格。如果为 None,则使用全局网格资源(如果可用)。

返回值

一个包装 fn 的函数,它将返回一个 Partitioned 实例。

flax.linen.get_partition_spec(tree)[source]#

从包含 Partitioned 值的 PyTree 中提取 PartitionSpec 树。

flax.linen.get_sharding(tree, mesh)[source]#

从包含 Partitioned 值和网格的 PyTree 中提取 jax.sharding 树。

flax.linen.LogicallyPartitioned(value: Any, names: tuple[Optional[str], ...], mesh: jax._src.mesh.Mesh | None = None, rules: collections.abc.Sequence[tuple[str, Union[str, tuple[str, ...], NoneType]]] | None = None)[source]#
flax.linen.logical_axis_rules(rules)[source]#

用于设置逻辑到网格轴绑定关系的上下文管理器。

flax.linen.set_logical_axis_rules(rules)[source]#

设置全局逻辑轴到网格轴绑定关系。

flax.linen.get_logical_axis_rules()[source]#

返回全局逻辑轴到网格轴绑定关系。

flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)[source]#

计算数组的布局。

规则按优先级顺序排列,由对组成:(ArrayDimensionName, MeshDimensionName),这意味着给定数组维度(如果存在且未使用)应该在给定网格维度(如果存在且未使用)上进行分片。

数组的布局表示为元组,每个元素对应数组中的一个维度。元素为 None,或为网格维度的名称,表示数组的此维度在网格的此维度上进行分片。

例如,给定一个具有

array_dim_names = ('batch', 'length', 'heads', 'features')

的数组,布局规则为

rules = (('batch', 'X'),
         ('features', 'X'),
         ('heads', 'Y'),
         ('batch', 'Z'))

那么此函数将返回

PartitionSpec('X', None, 'Y', None)
参数
  • array_dim_names – 数组维度名称的元组或 None。

  • rules – 可选的逻辑到网格规则覆盖。默认情况下,使用从 axis_rules 函数定义的动态上下文设置的规则。

返回值

参数的 PartitionSpec。

flax.linen.logical_to_mesh(tree, rules=None)[source]#

将逻辑 PartitionSpecs 的 PyTrees 应用于 logical_to_mesh_axes。

flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)[source]#

将逻辑 PartitionSpecs 的 PyTrees 转换为分片。

flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)[source]#

使用逻辑轴名称的 jit’s with_sharding_constraint 版本。

flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)[source]#

使用 LogicallyPartitioned 包装函数的返回值。

示例

>>> import flax.linen as nn
>>> kernel_init = nn.with_logical_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
参数
  • fn – 要包装的函数。通常,这是一个初始化器。

  • names – 传递给 LogicallyPartitioned 的逻辑轴。

  • mesh – 用于分区的网格。如果为 None,则使用全局网格资源(如果可用)。

  • rules – 可选的逻辑到网格规则使用。如果为 None,则在可用时使用全局规则。

返回值

包装 fn 的函数,它将返回一个 LogicallyPartitioned 实例。