Google Flax 框架快速入门:手写数字识别实战教程

Google Flax 框架快速入门:手写数字识别实战教程

flax Flax is a neural network library for JAX that is designed for flexibility. flax 项目地址: https://gitcode.com/gh_mirrors/fl/flax

前言

Google Flax 是一个基于 JAX 构建的神经网络库,它提供了简洁高效的 API 来构建和训练深度学习模型。本教程将带领读者从零开始,使用 Flax 框架实现一个经典的卷积神经网络(CNN)来完成 MNIST 手写数字识别任务。

环境准备

首先需要安装 Flax 库,推荐使用最新版本:

!pip install -q flax>=0.7.5

数据加载与预处理

在深度学习中,数据准备是第一步。我们使用 TensorFlow Datasets (TFDS) 加载 MNIST 数据集,并进行必要的预处理:

import tensorflow_datasets as tfds
import tensorflow as tf

def get_datasets(num_epochs, batch_size):
    """加载并预处理MNIST数据集"""
    # 加载原始数据集
    train_ds = tfds.load('mnist', split='train')
    test_ds = tfds.load('mnist', split='test')
    
    # 归一化像素值到[0,1]范围
    normalize = lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255.,
        'label': sample['label']
    }
    train_ds = train_ds.map(normalize)
    test_ds = test_ds.map(normalize)
    
    # 训练集增强:重复、打乱、分批
    train_ds = train_ds.repeat(num_epochs).shuffle(1024)
    train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
    
    # 测试集处理
    test_ds = test_ds.shuffle(1024)
    test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
    
    return train_ds, test_ds

构建卷积神经网络

使用 Flax 的 Linen API 可以轻松构建神经网络模型。下面是一个典型的 CNN 结构,包含两个卷积层和两个全连接层:

from flax import linen as nn

class CNN(nn.Module):
    """简单的CNN模型"""
    
    @nn.compact
    def __call__(self, x):
        # 第一卷积层
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        
        # 第二卷积层
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        
        # 展平后接全连接层
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)  # 输出10个类别
        return x

模型可视化

Flax 提供了方便的模型可视化工具,可以查看各层结构和计算量:

cnn = CNN()
print(cnn.tabulate(jax.random.key(0), 
                   jnp.ones((1, 28, 28, 1)),
                   compute_flops=True))

训练状态管理

Flax 推荐使用 TrainState 来统一管理训练状态,包括模型参数、优化器和训练指标:

from flax.training import train_state
from flax import struct
import optax
from clu import metrics

@struct.dataclass
class Metrics(metrics.Collection):
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss')

class TrainState(train_state.TrainState):
    metrics: Metrics

def create_train_state(module, rng, learning_rate, momentum):
    """初始化训练状态"""
    params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return TrainState.create(
        apply_fn=module.apply, 
        params=params, 
        tx=tx,
        metrics=Metrics.empty())

训练过程实现

单步训练函数

使用 JAX 的自动微分和 JIT 编译优化训练过程:

@jax.jit
def train_step(state, batch):
    """单步训练"""
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']).mean()
        return loss
    
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads)

指标计算函数

@jax.jit
def compute_metrics(*, state, batch):
    """计算评估指标"""
    logits = state.apply_fn({'params': state.params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()
    
    metric_updates = state.metrics.single_from_model_output(
        logits=logits, labels=batch['label'], loss=loss)
    return state.replace(metrics=state.metrics.merge(metric_updates))

完整训练流程

参数设置

num_epochs = 10
batch_size = 32
learning_rate = 0.01
momentum = 0.9

初始化

# 确保可复现性
tf.random.set_seed(0)
init_rng = jax.random.key(0)

# 加载数据
train_ds, test_ds = get_datasets(num_epochs, batch_size)

# 初始化训练状态
state = create_train_state(cnn, init_rng, learning_rate, momentum)

训练循环

metrics_history = {'train_loss': [], 'train_accuracy': [],
                  'test_loss': [], 'test_accuracy': []}

num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs

for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # 训练步骤
    state = train_step(state, batch)
    state = compute_metrics(state=state, batch=batch)
    
    # 每个epoch结束后评估
    if (step+1) % num_steps_per_epoch == 0:
        # 记录训练指标
        for metric, value in state.metrics.compute().items():
            metrics_history[f'train_{metric}'].append(value)
        state = state.replace(metrics=state.metrics.empty())
        
        # 计算测试集指标
        test_state = state
        for test_batch in test_ds.as_numpy_iterator():
            test_state = compute_metrics(state=test_state, batch=test_batch)
        
        # 记录测试指标
        for metric, value in test_state.metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)
        
        # 打印进度
        epoch = (step+1) // num_steps_per_epoch
        print(f"Epoch {epoch}: "
              f"Train Loss: {metrics_history['train_loss'][-1]:.4f}, "
              f"Accuracy: {metrics_history['train_accuracy'][-1]*100:.2f}% | "
              f"Test Loss: {metrics_history['test_loss'][-1]:.4f}, "
              f"Accuracy: {metrics_history['test_accuracy'][-1]*100:.2f}%")

结果可视化

训练完成后,我们可以绘制损失和准确率曲线:

import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
    ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}')
    ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}')
ax1.legend()
ax2.legend()
plt.show()

模型测试

最后,我们可以使用训练好的模型进行预测:

@jax.jit
def pred_step(state, batch):
    logits = state.apply_fn({'params': state.params}, batch['image'])
    return logits.argmax(axis=1)

# 获取测试批次并预测
test_batch = test_ds.as_numpy_iterator().next()
predictions = pred_step(state, test_batch)

# 可视化预测结果
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
    ax.set_title(f"Pred: {predictions[i]}")
    ax.axis('off')

总结

通过本教程,我们完成了以下工作:

  1. 使用 Flax Linen API 构建了一个 CNN 模型
  2. 实现了完整的数据加载和预处理流程
  3. 使用 JAX 的自动微分和 JIT 编译优化训练过程
  4. 实现了训练和评估的完整流程
  5. 可视化训练过程和预测结果

Flax 提供了简洁高效的 API,结合 JAX 的强大功能,使得实现和优化深度学习模型变得更加容易。这个简单的 CNN 模型在 MNIST 数据集上可以达到约 99% 的准确率,展示了 Flax 框架的实用性和高效性。

flax Flax is a neural network library for JAX that is designed for flexibility. flax 项目地址: https://gitcode.com/gh_mirrors/fl/flax

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柏旦谊Free

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值