示例:使用预训练的 Gemma 和 Flax NNX 进行推理#
此示例演示如何使用 Flax NNX 加载 Gemma 开源模型文件,并使用它们执行采样/推理以生成文本。你将使用用 Flax 和 JAX 编写的 Flax NNX gemma
模块进行模型参数配置和推理。
建议使用可以访问 A100 GPU 加速的 Google Colab 来运行代码。
安装#
安装必要的依赖项,包括 kagglehub
。
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
下载模型#
要使用 Gemma 模型,你需要一个 Kaggle 账户和 API 密钥
要创建账户,请访问 Kaggle 并单击“注册”。
如果你有账户(或创建账户后),你需要登录,转到你的 “设置”,然后在“API”下单击“创建新令牌”以生成并下载你的 Kaggle API 密钥。
在 Google Colab 中,在“密钥”下添加你的 Kaggle 用户名和 API 密钥,将用户名存储为
KAGGLE_USERNAME
,密钥存储为KAGGLE_KEY
。如果你使用的是 Kaggle Notebook 进行免费 TPU 或其他硬件加速,它在“加载项” > “密钥”下有一个密钥存储功能,以及访问存储密钥的说明。
然后运行下面的单元格。
import kagglehub
kagglehub.login()
如果一切顺利,应该显示 Kaggle 凭据已设置。 Kaggle 凭据已成功验证。
。
注意:在 Google Colab 中,你也可以按照上面的可选步骤 3,使用下面的代码验证 Kaggle 身份。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
现在,加载你要尝试的 Gemma 模型。下一个单元格中的代码使用 kagglehub.model_download
下载模型文件。
注意:对于较大的模型,例如 gemma 7b
和 gemma 7b-it
(指令),你可能需要具有大量内存的硬件加速器,例如 NVIDIA A100。
from IPython.display import clear_output
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
Python 导入#
from flax import nnx
import sentencepiece as spm
要与 Gemma 模型交互,你将使用来自 google/flax
GitHub 示例的 Flax NNX gemma
代码。由于它没有作为包公开,你需要使用以下解决方法从 GitHub 上的 Flax NNX examples/gemma
导入。
import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
# Create a temporary directory and clone the `flax` repo.
# Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
! git clone https://github.com/google/flax.git {tmp}/flax
sys.path.append(f"{tmp}/flax/examples/gemma")
import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.
加载和准备 Gemma 模型#
首先,加载 Gemma 模型参数以供 Flax 使用。
params = params_lib.load_and_format_params(ckpt_path)
接下来,加载使用 SentencePiece 库构建的标记器文件。
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True
然后,使用 Flax NNX gemma.transformer.TransformerConfig.from_params
函数,从检查点自动加载正确的配置。
注意:由于此版本中未使用令牌,词汇表大小小于输入嵌入的数量。
transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)
执行采样/推理#
基于你的模型和具有正确参数形状的标记器构建一个 Flax NNX gemma.Sampler
。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
)
你已准备好开始采样!
注意:此 Flax NNX gemma.Sampler
使用 JAX 的 即时 (JIT) 编译,因此更改输入形状会触发重新编译,这会减慢速度。为了获得最快、最有效的结果,请保持批量大小一致。
在 input_batch
中编写提示并执行推理。随意调整 total_generation_steps
(生成响应时执行的步数)。
input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]
out_data = sampler(
input_strings=input_batch,
total_generation_steps=300, # The number of steps performed when generating a response.
)
for input_string, out_string in zip(input_batch, out_data.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
print()
print(10*'#')
Prompt:
# Python program for implementation of Bubble Sort
def bubbleSort(arr):
Output:
for i in range(len(arr)):
for j in range(len(arr) - i - 1):
if arr[j] > arr[j + 1]:
swap(arr, j, j + 1)
def swap(arr, i, j):
temp = arr[i]
arr[i] = arr[j]
arr[j] = temp
# Driver code
arr = [5, 2, 8, 3, 1, 9]
print("Unsorted array:")
print(arr)
bubbleSort(arr)
print("Sorted array:")
print(arr)
# Time complexity of Bubble sort O(n^2)
# where n is the length of the array
# Space complexity of Bubble sort O(1)
# as it only requires constant extra space for the swap operation
# This program uses the bubble sort algorithm to sort the given array in ascending order.
```python
# This program uses the bubble sort algorithm to sort the given array in ascending order.
def bubbleSort(arr):
for i in range(len(arr)):
for j in range(len(arr) - i - 1):
if arr[j] > arr[j + 1]:
swap(arr, j, j + 1)
def swap(
##########
你应该获得气泡排序算法的 Python 实现。