本文记录了一个简单的基于pytorch的图像多分类器模型构造过程,参考自Pytorch官方文档、磐创团队的《PyTorch官方教程中文版》以及余霆嵩的《PyTorch 模型训练实用教程》。从加载数据集开始,包括了模型设计、训练、测试等过程。
pytorch中文网:https://www.pytorchtutorial.com/
pytorch官方文档:https://pytorch.org/docs/stable/index.html
一. 加载数据
Pytorch的数据加载一般是用torch.utils.data.Dataset
与torch.utils.data.Dataloader
两个类联合进行。我们需要继承Dataset来定义自己的数据集类,然后在训练时用Dataloader加载自定义的数据集类。
1. 继承Dataset类并重写关键方法
pytorch的dataset类有两种:Map-style datasets和Iterable-style datasets。前者是我们常用的结构,而后者是当数据集难以(或不可能)进行随机读取时使用。在这里我们实现Map-style dataset。
继承torch.utils.data.Dataset
后,需要重写的方法有:__len__
与__getitem__
方法,其中__len__
方法需要返回所有数据的数量,而__getitem__
则是要依照给出的数据索引获取对应的tensor类型的Sample,除了这两个方法以外,一般还需要实现__init__
方法来初始化一些变量。话不多说,直接上代码。
'''
包括了各种数据集的读取处理,以及图像相关处理方法
'''
from torch.utils.data import Dataset
import torch
import os
import cv2
from Config import mycfg
import random
import numpy as np
class ImageClassifyDataset(Dataset):
def __init__(self, imagedir, labelfile, classify_num, train=True):
'''
这里进行一些初始化操作。
'''
self.imagedir = imagedir
self.labelfile = labelfile
self.classify_num = classify_num
self.img_list = []
# 读取标签
with open(self.labelfile, 'r') as fp:
lines = fp.readlines()
for line in lines:
filepath = os.path.join(self.imagedir, line.split(";")[0].replace('\\', '/'))
label = line.split(";")[1].strip(