Determined AI 项目教程:使用 PyTorch 实现 MNIST 手写数字识别

Determined AI 项目教程:使用 PyTorch 实现 MNIST 手写数字识别

determined Determined is an open-source machine learning platform that simplifies distributed training, hyperparameter tuning, experiment tracking, and resource management. Works with PyTorch and TensorFlow. determined 项目地址: https://gitcode.com/gh_mirrors/de/determined

前言

本教程将带领读者学习如何将一个现有的 PyTorch 模型迁移到 Determined AI 平台上。我们将以经典的 MNIST 手写数字识别任务为例,展示如何将一个基础的 PyTorch 模型转换为 Determined 的试验(Trial)类,并利用 Determined 提供的强大功能进行模型训练和管理。

模型迁移概述

将 PyTorch 模型迁移到 Determined 平台的主要优势在于可以立即获得分布式训练、超参数搜索等高级功能,而无需修改模型的核心逻辑。Determined 会自动处理训练循环、检查点保存、日志管理等繁琐但重要的工作。

迁移过程主要涉及创建一个继承自 determined.pytorch.PyTorchTrial 的试验类,并实现以下关键方法:

  1. 模型、优化器和学习率调度器的初始化
  2. 定义训练批次的前向传播和反向传播
  3. 定义验证批次的评估逻辑
  4. 加载训练数据集
  5. 加载验证数据集

准备工作

在开始之前,请确保:

  1. 已经安装并配置好 Determined 集群
  2. 本地机器上安装了 Determined CLI
  3. 设置好 DET_MASTER 环境变量指向 Determined 主节点

创建 PyTorchTrial 类

基础结构

试验类的基本框架如下:

import torch.nn as nn
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext

class MNISTTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext):
        # 初始化模型、优化器等
        pass
    
    def train_batch(self, batch, epoch_idx, batch_idx):
        # 训练批次处理
        pass
    
    def evaluate_batch(self, batch):
        # 验证批次处理
        pass
    
    def build_training_data_loader(self):
        # 构建训练数据加载器
        pass
    
    def build_validation_data_loader(self):
        # 构建验证数据加载器
        pass

初始化方法详解

__init__ 方法中,我们需要:

  1. 存储试验上下文供后续使用
  2. 创建模型并调用 wrap_model 进行包装
  3. 创建优化器并调用 wrap_optimizer 进行包装
def __init__(self, context: PyTorchTrialContext):
    self.context = context
    
    # 为分布式训练创建唯一的数据下载目录
    self.download_directory = f"/tmp/data-rank{self.context.distributed.get_rank()}"
    self.data_downloaded = False
    
    # 构建并包装模型
    self.model = self.context.wrap_model(
        nn.Sequential(
            nn.Conv2d(1, self.context.get_hparam("n_filters1"), 3, 1),
            nn.ReLU(),
            # 更多层...
            nn.LogSoftmax(),
        )
    )
    
    # 构建并包装优化器
    self.optimizer = self.context.wrap_optimizer(
        torch.optim.Adadelta(
            self.model.parameters(), 
            lr=self.context.get_hparam("learning_rate")
        )
    )

数据加载方法

Determined 使用 build_training_data_loaderbuild_validation_data_loader 方法加载数据集,这两个方法都应返回一个 DataLoader 对象。

def build_training_data_loader(self):
    if not self.data_downloaded:
        self.download_directory = data.download_dataset(
            download_directory=self.download_directory,
            data_config=self.context.get_data_config(),
        )
        self.data_downloaded = True
    
    train_data = data.get_dataset(self.download_directory, train=True)
    return DataLoader(train_data, batch_size=self.context.get_per_slot_batch_size())

def build_validation_data_loader(self):
    # 类似训练数据加载方法
    validation_data = data.get_dataset(self.download_directory, train=False)
    return DataLoader(validation_data, batch_size=self.context.get_per_slot_batch_size())

训练和评估方法

train_batch 方法处理单个训练批次,包括前向传播、反向传播和优化器更新:

def train_batch(self, batch, epoch_idx, batch_idx):
    data, labels = batch
    
    # 前向传播
    output = self.model(data)
    loss = torch.nn.functional.nll_loss(output, labels)
    
    # 反向传播和优化
    self.context.backward(loss)
    self.context.step_optimizer(self.optimizer)
    
    return {"loss": loss}

evaluate_batch 方法处理单个验证批次,计算评估指标:

def evaluate_batch(self, batch):
    data, labels = batch
    
    output = self.model(data)
    validation_loss = torch.nn.functional.nll_loss(output, labels).item()
    
    pred = output.argmax(dim=1, keepdim=True)
    accuracy = pred.eq(labels.view_as(pred)).sum().item() / len(data)
    
    return {"validation_loss": validation_loss, "accuracy": accuracy}

配置和运行实验

实验配置文件

Determined 使用 YAML 文件配置实验参数:

name: mnist_pytorch_const
hyperparameters:
  learning_rate: 1.0
  global_batch_size: 64
  n_filters1: 32
  n_filters2: 64
  dropout1: 0.25
  dropout2: 0.5
searcher:
  name: single
  metric: validation_loss
  smaller_is_better: true
entrypoint: python3 train.py --epochs 1

启动实验

使用 Determined CLI 启动实验:

det experiment create const.yaml .

其中 const.yaml 是配置文件,. 表示当前目录包含模型代码。

模型评估和监控

Determined 会自动记录训练和验证指标,可以通过 WebUI 查看:

  1. 在浏览器中访问 Determined 主节点地址
  2. 使用实验 ID 或描述查找你的实验
  3. 查看训练曲线、指标变化等信息

总结与进阶

通过本教程,我们学习了如何将 PyTorch 模型迁移到 Determined 平台。迁移后的模型可以立即利用 Determined 提供的分布式训练、超参数搜索等高级功能。接下来,你可以:

  1. 尝试修改模型结构或超参数
  2. 探索 Determined 的分布式训练功能
  3. 学习如何使用 Determined 进行超参数搜索
  4. 研究模型检查点和恢复训练功能

Determined 的强大功能可以帮助你更高效地进行深度学习实验,将更多精力集中在模型创新而非工程实现上。

determined Determined is an open-source machine learning platform that simplifies distributed training, hyperparameter tuning, experiment tracking, and resource management. Works with PyTorch and TensorFlow. determined 项目地址: https://gitcode.com/gh_mirrors/de/determined

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

左松钦Travis

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值