本文需要具备python编辑器和pytorch深度学习框架的语句基础知识
目录
前言
本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py、network.py、dataset.py以及train.py文件,功能是对本地的数据集进行分类。本文介绍逻辑是总分形式,即首先对总流程进行一个概括,然后分别介绍每个流程中的实现过程(代码+流程图+文字的介绍)。
对于整个项目的流程首先是加载本地数据集,然后导入Resnet网络,最后进行网络训练。整体来说一个完整的小项目,难度并不高,需要有一定的pytorch语句以及深度学习的基础。
mian.py文件是该项目的总文件,也是训练网络模型的运行文件,文本的介绍流程是随着该文件一 一对代码进行介绍。
main.py代码如下所示:
from dataset import data_dataloader #电脑本地写的读取数据的函数
from torch import nn #导入pytorch的nn模块
from torch import optim #导入pytorch的optim模块
from network import Res_net #电脑本地写的网络框架的函数
from train import train #电脑本地写的训练函数
def main():
# 以下是通过Data_dataloader函数输入为:数据的路径,数据模式,数据大小,batch的大小,有几线并用 (把dataset和Dataloader功能合在了一起)
train_loader = data_dataloader(data_path='./data', mode='train', size=64, batch_size=24, num_workers=4)
val_loader = data_dataloader(data_path='./data', mode='val', size=64, batch_size=24, num_workers=2)
test_loader = data_dataloader(data_path='./data', mode='test', size=64, batch_size=24, num_workers=2)
# 以下是超参数的定义
lr = 1e-4 #学习率
epochs = 10 #训练轮次
model = Res_net(2) # resnet网络
optimizer = optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.CrossEntropyLoss() # 损失函数
# 训练以及验证测试函数
train(model=model, optimizer=optimizer, loss_function=loss_function, train_data=train_loader, val_data=val_loader,test_data= test_loader, epochs=epochs)
if __name__ == '__main__':
main()
main.py流程图如图1所示:
图 1 main.py 代码流程图
一、dataset.py文件
main.py()前五行分别是导入相应的模块,其中dataset,network以及train是本地编写的文件。在mian()函数中的前几行代码中,我们使用dataset.py文件中的Data_dataloader函数导入训练集、验证集和测试集。Dataset文件是导入我们自己的本地数据库,其功能是得到所有的数据,将其变成pytorch能够识别的tensor数据,然后得到图片。
dataset.py文件代码如下所示:
import torch
import os,glob
import random
import csv
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
# 第一部分:通过三个步骤得到输出的tensor类型的数据
class Dataset_self(Dataset): #如果是nn.moduel 则是编写网络模型框架,这里需要继承的是dataset的数据,所以括号中的是Dataset
#第一步:初始化
def __init__(self,root,mode,resize,): #root是文件根目录,mode是选择什么样的数据集,resize是图像重新调整大小
super(Dataset_self, self).__init__()
self.resize = resize
self.root = root
self.name_label = {} #创建一个字典来保存每个文件的标签
#首先得到标签相对于的字典(标签和名称一一对应)
for name in sorted(os.listdir(os.path.join(root))): #排序并且用列表的形式打开文件夹
if not os.path.isdir(os.path.join(root,name)): #不是文件夹就不需要读取
continue
self.name_label[name] = len(self.name_label.keys()) #每个文件的名字为name_Label字典中有多少对键值对的个数
#print(self.name_label)
self.image,self.label = self.make_csv('images.csv') #编写一共函数来读取图片和标签的路径
#在得到image和label的基础上对图片数据进行一共划分 (注意:如果需要交叉验证就不需要验证集,只划分为训练集和测试集)
if mode == 'train':
self.image ,self.label= self.image[:int(0.6*len(self.image))],self.label[:int(0.6*len(self.label))]
if mode == 'val':
self.image ,self.label= self.image[int(0.6*len(self.image)):int(0.8*len(self.image))],self.label[int(0.6*len(self.label)):int(0.8*len(self.label))]
if mode == 'test':
self.image ,self.label= self.image[int(0.8*len(self.image)):],self.label[int(0.8*len(self.label)):]
# 获得图片和标签的函数
def make_csv(self,filename):
if not os.path.exists(os.path.join(self.root,filename)): #如果不存在汇总的目录就新建一个
images = []
for image in self.name_label.keys(): # 让image到name_label中的每个文件中去读取图片
images += glob.glob(os.path.join(self.root,image,'*jpg')) #加* 贪婪搜索关于jpg的所有文件
#print('长度为:{},第二张图片为:{}'.format(len(images),images[1]))
random.shuffle(images) #把images列表中的数据洗牌
# images[0]: ./data\ants\382971067_0bfd33afe0.jpg
with open(os.path.join(self.root,filename),mode='w',newline='') as f : #创建文件
writer = csv.writer(f)
for image in