Flax#
Neural Networks for JAX
Flax 为使用 JAX 进行神经网络研究和开发的研究人员和开发人员提供了灵活的端到端用户体验。Flax 使您能够充分利用 JAX 的强大功能。
Flax 的核心是 NNX - 一个简化的 API,可以更轻松地在 JAX 中创建、检查、调试和分析神经网络。 Flax NNX 对 Python 引用语义具有一流的支持,使用户能够使用常规 Python 对象来表达其模型。Flax NNX 是之前 Flax Linen API 的演变,它经历了多年的经验,才带来一个更简单和用户友好的 API。
注意
Flax Linen API 在不久的将来不会被弃用,因为大多数 Flax 用户仍然依赖此 API。但是,建议新用户使用 Flax NNX。查看 为什么选择 Flax NNX,了解 Flax NNX 和 Linen 之间的比较,以及我们创建新 API 的原因。
要将您的 Flax Linen 代码库迁移到 Flax NNX,请熟悉 NNX 基础知识中的 API,然后按照演变指南开始迁移。
特性#
Pythonic
Flax NNX 支持使用常规 Python 对象,提供直观和可预测的开发体验。
简单
Flax NNX 依赖于 Python 的对象模型,这为用户带来了简单性并提高了开发速度。
富有表现力
Flax NNX 允许通过其 Filter 系统对模型的状态进行细粒度控制。
熟悉
Flax NNX 通过 Functional API 可以非常容易地将对象与常规 JAX 代码集成。
基本用法#
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
@nnx.jit # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # in-place updates
return loss