背景介绍
本文是自己 使用图片数据集训练深度神经网络实现图片分类 的一次学习记录。其中数据集是b站小土堆评论区的蚂蚁蜜蜂数据集,训练和验证数据集总共只有400张左右的图片。数据集非常小,其实并不适合进行深度学习的训练。我只是想通过这个数据集来学习自己写dataset类并训练模型的整个过程。模型搭建使用的是PyTorch框架,模型是自己随便搭的,很简单,只有卷积(CNN)、最大池化(MaxPooling)、线性层(Linear)。如果有uu也打算写代码练习,可以b站搜小土堆获取数据集哈!
写这篇文章,一来是对自己编写代码过程的一次总结反思,二来详细记录实现模型训练的整个过程,既加深自己的记忆,又方便以后回忆学习。如果,其中所记录的点有问题,还烦请朋友们不要吝啬文字,欢迎在评论区指出、讨论!如果这篇学习记录对您有所启发的话,还请点赞收藏啦,感谢!!!
dataset类的实现
如果想使用自己的数据集,通常需要自己实现数据集的类。实现思路:写一个数据集类,继承自torch.utils.data.Dataset
;在类里重写重写__init__、__getitem__、__len__方法
,可见下方代码
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os
import cv2
trans = transforms.Compose([
transforms.ToTensor(),
# transforms.CenterCrop(256)
])
# 数据集继承自Dataset,重写init、getitem、len方法
class AB(Dataset):
def __init__(self, root, label):
self.root = root
self.label = label
self.data_path = os.path.join(self.root, self.label)
self.data = os.listdir(s