PyTorch学习笔记(五)保存和加载和训练模型

本文介绍了在PyTorch中如何保存和加载模型,特别是在训练CNN时保存参数以便中断后恢复训练。此外,还展示了训练ResNet50模型进行人民币面值识别任务的通用代码,包括数据加载、模型构建、训练参数设置和GPU使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Environment

  • OS: Ubuntu 14.04
  • Python version: 3.7
  • PyTorch version: 1.4.0
  • IDE: PyCharm


0. 写在前面

训练一个 CNN 模型往往需要一些时间,那么在训练过程中需要适时地保存模型。如果手动或意外中断了训练,那么之后从某个 checkpoint 恢复训练,而不用从头开始。这里记一下如何保存、加载、训练模型。

1. 保存和加载

模型在硬盘中存储为 .pkl 的二进制格式。一般地,可以保存整个模型,但推荐的是保存模型的参数。

  • torch.save 函数实现模型的保存
import torch
import torchvision

model = torchvision.models.resnet50(pretrained=True)
torch.save(model, 'resnet50.pkl')  # 保存整个模型
torch.save(model.state_dict(), 'resnet50_state_dict.pkl')  # 保存模型参数


  • torch.load 函数实现模型的加载
import torch

model = torch.load('resnet50.pkl')  # 载入整个模型

若加载的是模型的参数,则需要构建一个模型,再调用 load_state_dict 方法载入参数

import torch
import torchvision

state_dict = torch.load('resnet50_state_dict.pkl')  # 载入模型参数
model = torchvision.models.resnet50(pretrained=False)
model.load_state_dict(state_dict)
# <All keys matched successfully>

2. 训练代码

训练一个 ResNet50 模型,完成 TinyMind人民币面值识别 任务。数据集文件夹中的目录结构见 PyTorch学习笔记(二)读取数据

├── data.py
├── models
├── rmbdata
└── train.py

以下为 train.py 中较为通用的代码

首先,导入需要的模块

import os
import time
import copy
import argparse

import torch
from torchvision.transforms import Compose, Resize, CenterCrop, RandomErasing, RandomGrayscale, ToTensor
from torch.utils.data import DataLoader
import torchvision
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

from data import RMBFaceValueDataset

接着,定义训练模型的函数

def train_model(model, criterion, optimizer, scheduler, num_epochs=36):
    """
    Train the model.

    Params:
    model: torch.nn.Module
        Neural network model
    criterion: torch.nn.Module
        Loss function
    optimizer: torch.optim.Optimizer
        Optimization strategy
    scheduler: torch.optim.lr_scheduler._LRScheduler
        Learning rate scheduler
    num_epochs: int
        Number of epochs for training

    Returns:
    model: torch.nn.Module
        Trained model
    """
    since = time.time()

    # initialize the best accuracy and its corresponding model
    best_acc = 0.0
    if use_gpu and len(device_ids) > 1:  # whether the model has been packed as Parallel
        best_model_wts = copy.deepcopy(model.module.state_dict())
    else:
        best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(start_epoch, num_epochs + 1):
        epoch_since = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs))

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            tic_batch = time.time()
            # Iterate over data
            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs, labels = inputs.to(device), labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值