Environment
- OS: Ubuntu 14.04
- Python version: 3.7
- PyTorch version: 1.4.0
- IDE: PyCharm
0. Preface
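Training a CNN model often takes a long time, so the model should be saved at suitable points during training. If training is then interrupted, manually or by accident, it can be resumed from a checkpoint instead of starting over. This note records how to save, load, and train a model.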
1. Saving and Loading
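Here the model is stored on disk as a binary file with a .pkl extension (torch.save serializes objects with Python's pickle, and .pt and .pth are also common extensions). You can save the entire model, but the recommended approach is to save only the model's parameters, i.e. its state_dict. A pickled whole model depends on the exact class definitions and module paths being available when it is loaded.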
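To make "the model's parameters" concrete: a state_dict is an ordered dictionary mapping names to tensors, and it includes buffers such as BatchNorm running statistics, not only learnable weights. The small Conv/BatchNorm model below is a hypothetical stand-in. ResNet50's state_dict contains the same kinds of entries, just many more of them.

```python
import torch.nn as nn

# Hypothetical small model used only to inspect what a state_dict holds
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

# Each entry maps a dotted name (submodule index + attribute) to a tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 3, 3, 3)
# 0.bias (8,)
# 1.weight (8,)
# 1.bias (8,)
# 1.running_mean (8,)
# 1.running_var (8,)
# 1.num_batches_tracked ()
```

ReLU has no parameters or buffers, so it contributes no entries.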
The torch.save function saves a model:
import torch
import torchvision
model = torchvision.models.resnet50(pretrained=True)
torch.save(model, 'resnet50.pkl')  # save the entire model
torch.save(model.state_dict(), 'resnet50_state_dict.pkl')  # save only the model parameters
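torch.save writes the tensor values as they are at that moment. The dictionary returned by state_dict(), however, holds tensors that share memory with the live model. That is why the training code in section 2 wraps state_dict() in copy.deepcopy when it keeps the best weights in memory. The minimal sketch below uses a hypothetical nn.Linear stand-in and simulates an optimizer update with an in-place add:

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # hypothetical stand-in model

alias = model.state_dict()                    # tensors share storage with the parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the current values

with torch.no_grad():
    model.weight.add_(1.0)  # simulate an in-place optimizer update

print(torch.equal(alias['weight'], model.weight))     # True: the alias followed the update
print(torch.equal(snapshot['weight'], model.weight))  # False: the snapshot kept the old values
```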
The torch.load function loads a model:
import torch
model = torch.load('resnet50.pkl')  # load the entire model
If only the parameters were saved, first construct the model and then call its load_state_dict method to load them:
import torch
import torchvision
state_dict = torch.load('resnet50_state_dict.pkl')  # load the model parameters
model = torchvision.models.resnet50(pretrained=False)
model.load_state_dict(state_dict)
# <All keys matched successfully>
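The preface mentions resuming from a checkpoint, and train_model below starts its epoch loop at start_epoch. A common PyTorch pattern for this is to save one dictionary that holds the epoch together with the model, optimizer, and scheduler state_dicts. The sketch below assumes that layout. The helper names save_checkpoint and load_checkpoint, the dictionary keys, the file name checkpoint.pkl, and the tiny nn.Linear stand-in are all hypothetical. The sketch unwraps nn.DataParallel before saving, mirroring the model.module check in train_model. It also passes map_location to torch.load so that a checkpoint saved on a GPU can be restored on the CPU.

```python
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR


def save_checkpoint(path, model, optimizer, scheduler, epoch, best_acc):
    """Save everything needed to resume training (hypothetical helper)."""
    # Unwrap nn.DataParallel so the saved keys carry no 'module.' prefix
    net = model.module if isinstance(model, nn.DataParallel) else model
    torch.save({
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_acc': best_acc,
    }, path)


def load_checkpoint(path, model, optimizer, scheduler, device='cpu'):
    """Restore a checkpoint written by save_checkpoint; return (start_epoch, best_acc)."""
    # map_location puts tensors that were saved on a GPU onto `device`
    checkpoint = torch.load(path, map_location=device)
    net = model.module if isinstance(model, nn.DataParallel) else model
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Resume with the epoch after the one that was saved
    return checkpoint['epoch'] + 1, checkpoint['best_acc']


def make_training_objects():
    # Hypothetical stand-ins for the real ResNet50 / SGD / StepLR setup
    model = nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = StepLR(optimizer, step_size=7, gamma=0.1)
    return model, optimizer, scheduler


# Save after (say) epoch 5 ...
model, optimizer, scheduler = make_training_objects()
save_checkpoint('checkpoint.pkl', model, optimizer, scheduler, epoch=5, best_acc=0.9)

# ... and later resume into freshly constructed objects
model2, optimizer2, scheduler2 = make_training_objects()
start_epoch, best_acc = load_checkpoint('checkpoint.pkl', model2, optimizer2, scheduler2)
print(start_epoch, best_acc)                     # 6 0.9
print(torch.equal(model.weight, model2.weight))  # True

# Why unwrapping matters: a DataParallel wrapper prefixes every key with 'module.'
print(list(nn.DataParallel(nn.Linear(4, 2)).state_dict().keys()))  # ['module.weight', 'module.bias']
```

Saving the unwrapped module keeps the checkpoint loadable into a plain, single-device model without renaming keys.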
2. Training Code
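Here we train a ResNet50 model for the TinyMind RMB (Renminbi) face value recognition task. The directory structure of the dataset folder is described in PyTorch Study Notes (2): Reading Data. The project is laid out as follows: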
├── data.py
├── models
├── rmbdata
└── train.py
The following is the fairly general-purpose part of train.py.
First, import the required modules:
import os
import time
import copy
import argparse
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, RandomErasing, RandomGrayscale, ToTensor
from torch.utils.data import DataLoader
import torchvision
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from data import RMBFaceValueDataset
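The optimizer and scheduler passed to train_model are built from the imported SGD and StepLR. The sketch below shows how StepLR decays the learning rate when scheduler.step() is called once per epoch, after the optimizer steps, which is the order expected since PyTorch 1.1. The stand-in model and the values lr=0.1, step_size=2, and gamma=0.1 are assumptions chosen for illustration, not the settings used in train.py.

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 2)  # hypothetical stand-in model
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)  # multiply lr by 0.1 every 2 epochs

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]['lr'])  # learning rate used during this epoch
    optimizer.step()  # a real loop would run forward/backward over all batches here
    scheduler.step()  # advance the schedule once per epoch, after the optimizer steps

print([round(lr, 6) for lr in lrs])  # [0.1, 0.1, 0.01, 0.01, 0.001, 0.001]
```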
Next, define the function that trains the model:
def train_model(model, criterion, optimizer, scheduler, num_epochs=36):
    """
    Train the model.

    Params:
        model: torch.nn.Module
            Neural network model
        criterion: torch.nn.Module
            Loss function
        optimizer: torch.optim.Optimizer
            Optimization strategy
        scheduler: torch.optim.lr_scheduler._LRScheduler
            Learning rate scheduler
        num_epochs: int
            Number of epochs for training

    Returns:
        model: torch.nn.Module
            Trained model
    """
    # use_gpu, device_ids, device, start_epoch and dataloaders are not parameters;
    # they are module-level variables of train.py (not shown in this part)
    since = time.time()

    # initialize the best accuracy and its corresponding model
    best_acc = 0.0
    if use_gpu and len(device_ids) > 1:  # multi-GPU: the model is wrapped (e.g. in nn.DataParallel), so unwrap it via .module
        best_model_wts = copy.deepcopy(model.module.state_dict())
    else:
        best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(start_epoch, num_epochs + 1):
        epoch_since = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs))

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            tic_batch = time.time()

            # Iterate over data
            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs, labels = inputs.to(device), labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()