flax.struct 包#
用于定义可与 jax 变换一起使用的自定义类的实用程序。
- flax.struct.dataclass(clz, **kwargs)[source]#
创建一个可以传递给函数变换的类。
注意
从
PyTreeNode
继承以避免在使用 PyType 时出现类型检查问题。像
jax.jit
和jax.grad
这样的 Jax 变换需要不可变的对象,并且可以使用jax.tree_util
方法对其进行映射。dataclass
装饰器使定义可以安全地传递给 Jax 的自定义类变得容易。例如>>> from flax import struct >>> import jax >>> from typing import Any, Callable >>> @struct.dataclass ... class Model: ... params: Any ... # use pytree_node=False to indicate an attribute should not be touched ... # by Jax transformations. ... apply_fn: Callable = struct.field(pytree_node=False) ... def __apply__(self, *args): ... return self.apply_fn(*args) >>> params = {} >>> params_b = {} >>> apply_fn = lambda v, x: x >>> model = Model(params, apply_fn) >>> # model.params = params_b # Model is immutable. This will raise an error. >>> model_b = model.replace(params=params_b) # Use the replace method instead. >>> # This class can now be used safely in Jax to compute gradients w.r.t. the >>> # parameters. >>> model = Model(params, apply_fn) >>> loss_fn = lambda model: 3. >>> model_grad = jax.grad(loss_fn)(model)
注意,数据类具有自动生成的
__init__
,其中构造函数的参数和创建实例的属性一一对应。这种对应关系使这些对象成为有效的容器,这些容器可以与 JAX 变换以及更广泛的jax.tree_util
库一起使用。有时需要“智能构造函数”,例如,因为某些属性可以(可选地)从其他属性派生出来。使用 Flax 数据类执行此操作的方法是创建一个静态或类方法,该方法提供智能构造函数。这样一来,就可以保留
jax.tree_util
使用的简单构造函数。考虑以下示例>>> @struct.dataclass ... class DirectionAndScaleKernel: ... direction: jax.Array ... scale: jax.Array ... @classmethod ... def create(cls, kernel): ... scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True) ... direction = direction / scale ... return cls(direction, scale)
- 参数
clz – 将被装饰器转换的类。
- 返回值
新类。
- class flax.struct.PyTreeNode(*args, **kwargs)[source]#
应该像 JAX pytree 节点一样工作的 dataclass 的基类。
请参阅
flax.struct.dataclass
以获取jax.tree_util
行为。该基类还避免了使用 PyType 时出现类型检查错误。示例
>>> from flax import struct >>> import jax >>> from typing import Any, Callable >>> class Model(struct.PyTreeNode): ... params: Any ... # use pytree_node=False to indicate an attribute should not be touched ... # by Jax transformations. ... apply_fn: Callable = struct.field(pytree_node=False) ... def __apply__(self, *args): ... return self.apply_fn(*args) >>> params = {} >>> params_b = {} >>> apply_fn = lambda v, x: x >>> model = Model(params, apply_fn) >>> # model.params = params_b # Model is immutable. This will raise an error. >>> model_b = model.replace(params=params_b) # Use the replace method instead. >>> # This class can now be used safely in Jax to compute gradients w.r.t. the >>> # parameters. >>> model = Model(params, apply_fn) >>> loss_fn = lambda model: 3. >>> model_grad = jax.grad(loss_fn)(model)