加载数据集#

Open in Colab

用 Jax+Flax 编写的的神经网络期望其输入数据为 jax.numpy 数组实例。因此,从任何源加载数据集就像将其转换为 jax.numpy 类型并将其重新整形为网络的适当维度一样简单。

例如,本指南演示了如何使用 Torchvision、Tensorflow 和 Hugging Face 的 API 导入 MNIST。我们将把整个数据集加载到内存中。对于不适合内存的数据集,该过程类似,但应该以批处理方式进行。

MNIST 数据集包含 28x28 像素的手写数字灰度图像,并具有指定的 60k/10k 训练/测试分割。任务是预测每个图像的正确类别(数字 0,…,9)。

假设基于 CNN 的分类器,输入数据应该具有形状 (B, 28, 28, 1),其中尾随的单例维度表示灰度图像通道。

标签只是表示与图像相对应的数字的整数。因此,标签应该具有形状 (B,),以使用 optax.softmax_cross_entropy_with_integer_labels 进行损失计算。

import numpy as np
import jax.numpy as jnp

torchvision.datasets 加载#

import torchvision
def get_dataset_torch():
    mnist = {
        'train': torchvision.datasets.MNIST('./data', train=True, download=True),
        'test': torchvision.datasets.MNIST('./data', train=False, download=True)
    }

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy()
        }

        # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # torchvision returns shape (B, 28, 28).
        # hence, append the trailing channel dimension.
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_torch()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16

tensorflow_datasets 加载#

import tensorflow_datasets as tfds
def get_dataset_tf():
    mnist = tfds.builder('mnist')
    mnist.download_and_prepare()

    ds = {}

    for split in ['train', 'test']:
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

    return ds['train'], ds['test']
train, test = get_dataset_tf()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16

从 🤗 Hugging Face datasets 加载#

#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install
from datasets import load_dataset
def get_dataset_hf():
    mnist = load_dataset("mnist")

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label'])
        }

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # append trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_hf()
print(train['image'].shape, train['image'].dtype)
print(train['label'].shape, train['label'].dtype)
print(test['image'].shape, test['image'].dtype)
print(test['label'].shape, test['label'].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16