亚麻#

使用 JAX 的神经网络


亚麻为使用 JAX 和神经网络的研究人员提供端到端且灵活的用户体验。亚麻提供了JAX 的全部功能。它由松散耦合的库组成,这些库通过端到端集成的指南示例 展示出来。

亚麻被数百个项目(还在不断增长) 使用,既包括开源社区(如 Hugging Face)也包括谷歌(如 GeminiImagenScenicBig Vision)。

特性#

安全

亚麻的设计目标是正确性和安全性。由于其不可变的模块和函数式 API,亚麻有助于缓解在 JAX 中处理状态时出现的错误。

控制

亚麻通过其变量集合、RNG 集合和可变性条件,提供了比大多数神经网络框架更细粒度的控制和表达能力。

函数式 API

亚麻的函数式 API 通过 vmap、scan 等提升的转换从根本上重新定义了模块的功能,同时还实现了与 Optax 和 Chex 等其他 JAX 库的无缝集成。

简洁的代码

亚麻的compact 模块允许直接在调用位置定义子模块,从而使代码更易于阅读并避免重复。


安装#

pip install flax
# or to install the latest version of Flax:
pip install --upgrade git+https://github.com/google/flax.git

亚麻安装了 JAX 的普通 CPU 版本,如果你需要自定义版本,请查看JAX 的安装页面

基本用法#

class MLP(nn.Module):                    # create a Flax Module dataclass
  out_dims: int

  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(128)(x)                 # create inline Flax Module submodules
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)       # shape inference
    return x

model = MLP(out_dims=10)                 # instantiate the MLP model

x = jnp.empty((4, 28, 28, 1))            # generate random data
variables = model.init(random.key(42), x)# initialize the weights
y = model.apply(variables, x)            # make forward pass

了解更多#

快速入门
quick_start.html
词汇表
glossary.html
开发人员说明
developer_notes/index.html
亚麻哲学
philosophy.html

生态系统#

亚麻中的知名示例包括

NLP 和计算机视觉模型

用于文本到图像生成的模型

用于文本生成的 5400 亿参数模型

文本到图像扩散模型

用于大规模计算机视觉的库

大规模计算机视觉模型

开源高性能 LLM

大型语言模型

设备上可微分强化学习环境