示例:使用预训练的 Gemma 和 Flax NNX 进行推理

示例:使用预训练的 Gemma 和 Flax NNX 进行推理#

此示例演示如何使用 Flax NNX 加载 Gemma 开源模型文件,并使用它们执行采样/推理以生成文本。你将使用用 Flax 和 JAX 编写的 Flax NNX gemma 模块进行模型参数配置和推理。

Gemma 是基于 Google DeepMind 的 Gemini 的一系列轻量级、最先进的开源模型。阅读更多关于 GemmaGemma 2 的信息。

建议使用可以访问 A100 GPU 加速的 Google Colab 来运行代码。

安装#

安装必要的依赖项,包括 kagglehub

! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope

下载模型#

要使用 Gemma 模型,你需要一个 Kaggle 账户和 API 密钥

  1. 要创建账户,请访问 Kaggle 并单击“注册”。

  2. 如果你有账户(或创建账户后),你需要登录,转到你的 “设置”,然后在“API”下单击“创建新令牌”以生成并下载你的 Kaggle API 密钥。

  3. Google Colab 中,在“密钥”下添加你的 Kaggle 用户名和 API 密钥,将用户名存储为 KAGGLE_USERNAME,密钥存储为 KAGGLE_KEY。如果你使用的是 Kaggle Notebook 进行免费 TPU 或其他硬件加速,它在“加载项” > “密钥”下有一个密钥存储功能,以及访问存储密钥的说明。

然后运行下面的单元格。

import kagglehub
kagglehub.login()

如果一切顺利,应该显示 Kaggle 凭据已设置。 Kaggle 凭据已成功验证。

注意:在 Google Colab 中,你也可以按照上面的可选步骤 3,使用下面的代码验证 Kaggle 身份。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

现在,加载你要尝试的 Gemma 模型。下一个单元格中的代码使用 kagglehub.model_download 下载模型文件。

注意:对于较大的模型,例如 gemma 7bgemma 7b-it(指令),你可能需要具有大量内存的硬件加速器,例如 NVIDIA A100。

from IPython.display import clear_output

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'

Python 导入#

from flax import nnx
import sentencepiece as spm

要与 Gemma 模型交互,你将使用来自 google/flax GitHub 示例的 Flax NNX gemma 代码。由于它没有作为包公开,你需要使用以下解决方法从 GitHub 上的 Flax NNX examples/gemma 导入。

import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
  # Create a temporary directory and clone the `flax` repo.
  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  import params as params_lib
  import sampler as sampler_lib
  import transformer as transformer_lib
  sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.

加载和准备 Gemma 模型#

首先,加载 Gemma 模型参数以供 Flax 使用。

params = params_lib.load_and_format_params(ckpt_path)

接下来,加载使用 SentencePiece 库构建的标记器文件。

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True

然后,使用 Flax NNX gemma.transformer.TransformerConfig.from_params 函数,从检查点自动加载正确的配置。

注意:由于此版本中未使用令牌,词汇表大小小于输入嵌入的数量。

transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)

执行采样/推理#

基于你的模型和具有正确参数形状的标记器构建一个 Flax NNX gemma.Sampler

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

你已准备好开始采样!

注意:此 Flax NNX gemma.Sampler 使用 JAX 的 即时 (JIT) 编译,因此更改输入形状会触发重新编译,这会减慢速度。为了获得最快、最有效的结果,请保持批量大小一致。

input_batch 中编写提示并执行推理。随意调整 total_generation_steps(生成响应时执行的步数)。

input_batch = [
    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
  ]

out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=300,  # The number of steps performed when generating a response.
  )

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
  print()
  print(10*'#')
Prompt:

# Python program for implementation of Bubble Sort

def bubbleSort(arr):
Output:

    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(arr, i, j):
    temp = arr[i]
    arr[i] = arr[j]
    arr[j] = temp


# Driver code
arr = [5, 2, 8, 3, 1, 9]
print("Unsorted array:")
print(arr)
bubbleSort(arr)
print("Sorted array:")
print(arr)


# Time complexity of Bubble sort O(n^2)
# where n is the length of the array


# Space complexity of Bubble sort O(1)
# as it only requires constant extra space for the swap operation


# This program uses the bubble sort algorithm to sort the given array in ascending order.

```python
# This program uses the bubble sort algorithm to sort the given array in ascending order.

def bubbleSort(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(

##########

你应该获得气泡排序算法的 Python 实现。