指南# 变换 基本示例 变换方法 状态传播 变换子状态(提升类型) 规则和限制 轴元数据 在多个设备上扩展 概述 定义具有指定分片的模型 初始化分片模型 从检查点加载分片模型 编译训练循环 性能分析 逻辑轴注释 使用过滤器 过滤器协议 过滤器 DSL 分组状态 随机性 Rngs、RngStream 和 RngState 过滤随机状态 重新播种 分割 PRNG 密钥 变换 性能考量 异步分发 降低 Python 开销 从 Flax Linen 到 NNX 的演变 基本 Module 定义 变量创建 训练步骤和编译 集合和变量类型 使用多种方法 变换 扫描层 在 Flax NNX 中使用 TrainState 一起使用 Flax NNX 和 Linen 子模块是你所需要的 基础知识 处理 RNG 密钥 NNX 变量类型与 Linen 集合 分区元数据 提升的变换 模型手术 Pythonic nnx.Module 操作 创建没有内存分配的抽象模型或状态 检查点手术 部分初始化 保存和加载检查点 设置 保存检查点 恢复检查点 以纯字典形式保存和恢复 当检查点结构不同时恢复 多进程检查点 其他检查点功能 Flax NNX 与 JAX 变换 差异 混合 Flax NNX 和 JAX 变换 从 Haiku 迁移到 Flax 基本模块定义 变量创建 训练步骤和编译 处理非参数状态 使用多种方法 变换 扫描层 顶层 Haiku 函数与顶层 Flax 模块 示例:使用预训练的 Gemma 进行 Flax NNX 推理 安装 下载模型 Python 导入 加载并准备 Gemma 模型 执行采样/推理