flax.serialization 包#

Jax 的序列化实用程序。

所有带有状态的 Flax 类(例如,Optimizer)都可以转换为 numpy 数组的状态字典,以便于序列化。

状态字典#

flax.serialization.from_state_dict(target, state, name='.')[source]#

使用状态字典恢复给定目标的状态。

此函数将当前目标作为参数。这让我们知道目标的确切结构,并让我们添加断言以确保形状和数据类型不会改变。

实际上,target 中的任何叶子值都不会被实际使用。仅使用树结构、形状和数据类型。

参数
  • target – 应恢复其状态的对象。

  • state – 由 to_state_dict 生成的字典,其中包含 target 的所需新状态。

  • name – 获取的分支名称,用于改进反序列化错误消息。

返回值

具有已恢复状态的对象副本。

flax.serialization.to_state_dict(target)[source]#

返回一个包含给定目标状态的字典。

flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)[source]#

注册用于序列化的类型。

参数
  • ty – 要注册的类型

  • ty_to_state_dict – 一个函数,它接收 ty 的实例并将其状态作为字典返回。

  • ty_from_state_dict – 一个函数,它接收 ty 的实例和一个状态字典,并返回具有已恢复状态的实例的副本。

  • override – 覆盖先前注册的序列化处理程序(默认值:False)。

使用 MessagePack 进行序列化#

flax.serialization.msgpack_serialize(pytree, in_place=False)[source]#

将数据结构保存为 msgpack 格式的字节。

仅支持具有数组叶子的 python 树的低级函数,对于自定义对象,请使用 to_bytes。它将超过 MAX_CHUNK_SIZE 的数组拆分为多个块。

参数
  • pytree – 包含 python 原语和数组叶子的字典、列表、元组的 python 树。

  • in_place – 布尔值,指定是否应就地修改 pytree。

返回值

pytree 的 msgpack 编码字节。

flax.serialization.msgpack_restore(encoded_pytree)[source]#

从 msgpack 格式的字节中恢复数据结构。

仅支持具有数组叶子的 python 树的低级函数,对于自定义对象,请使用 from_bytes

参数

encoded_pytree – python 树的 msgpack 编码字节。

返回值

包含 python 原语和数组叶子的字典、列表、元组的 python 树。

flax.serialization.to_bytes(target)[source]#

将优化器或其他对象保存为 msgpack 序列化的状态字典。

参数

target – 带有状态字典注册的模板对象,要序列化为 msgpack 格式。通常是 flax 模型或优化器。

返回值

target 对象的 msgpack 编码状态字典的字节。

flax.serialization.from_bytes(target, encoded_bytes)[source]#

从 msgpack 序列化的状态字典中恢复优化器或其他对象。

参数
  • target – 与从 encoded_bytes 反序列化的结构匹配的,带有状态字典注册的模板对象。

  • encoded_bytes – 与 target 结构同构的 msgpack 序列化对象。通常是 flax 模型或优化器。

返回值

一个新的与 target 结构同构的对象,其中包含来自保存数据的更新叶子数据。