flax.serialization 包
Jax 的序列化实用程序。
所有带有状态的 Flax 类(例如,Optimizer)都可以转换为 numpy 数组的状态字典,以便于序列化。
状态字典
-
flax.serialization.from_state_dict(target, state, name='.')[source]
使用状态字典恢复给定目标的状态。
此函数将当前目标作为参数。这让我们知道目标的确切结构,并让我们添加断言以确保形状和数据类型不会改变。
实际上,target
中的任何叶子值都不会被实际使用。仅使用树结构、形状和数据类型。
- 参数
-
- 返回值
具有已恢复状态的对象副本。
-
flax.serialization.to_state_dict(target)[source]
返回一个包含给定目标状态的字典。
-
flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)[source]
注册用于序列化的类型。
- 参数
ty – 要注册的类型
ty_to_state_dict – 一个函数,它接收 ty 的实例并将其状态作为字典返回。
ty_from_state_dict – 一个函数,它接收 ty 的实例和一个状态字典,并返回具有已恢复状态的实例的副本。
override – 覆盖先前注册的序列化处理程序(默认值:False)。
使用 MessagePack 进行序列化
-
flax.serialization.msgpack_serialize(pytree, in_place=False)[source]
将数据结构保存为 msgpack 格式的字节。
仅支持具有数组叶子的 python 树的低级函数,对于自定义对象,请使用 to_bytes
。它将超过 MAX_CHUNK_SIZE 的数组拆分为多个块。
- 参数
-
- 返回值
pytree 的 msgpack 编码字节。
-
flax.serialization.msgpack_restore(encoded_pytree)[source]
从 msgpack 格式的字节中恢复数据结构。
仅支持具有数组叶子的 python 树的低级函数,对于自定义对象,请使用 from_bytes
。
- 参数
encoded_pytree – python 树的 msgpack 编码字节。
- 返回值
包含 python 原语和数组叶子的字典、列表、元组的 python 树。
-
flax.serialization.to_bytes(target)[source]
将优化器或其他对象保存为 msgpack 序列化的状态字典。
- 参数
target – 带有状态字典注册的模板对象,要序列化为 msgpack 格式。通常是 flax 模型或优化器。
- 返回值
target
对象的 msgpack 编码状态字典的字节。
-
flax.serialization.from_bytes(target, encoded_bytes)[source]
从 msgpack 序列化的状态字典中恢复优化器或其他对象。
- 参数
-
- 返回值
一个新的与 target
结构同构的对象,其中包含来自保存数据的更新叶子数据。