作者因为课题需求,刚接触Pytorch,这里是我的学习笔记分享,一方面作为笔记记录,加深印象,另一方面也是希望可以帮助大家。
课程来源:
PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
https://www.bilibili.com/video/BV1hE411t7RN?p=7&vd_source=1633322d160e1c2471aa2b3eb28bde0a
蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq
代码讲解:主要内容是数据集加载
导入package:
from torch.utils.data import Dataset
from PIL import Image
import os
创建了一个MyData类用于存放数据集,继承Dateset的类:
class MyData(Dataset):
在实际项目时,要构建自己的数据集,需要继承Dateset的类,Pytorch才能读取使用。想要实现这个类,必须要重写3个方法:init(self, 参数…)、 getitem(self, index)、len(self)。
1)进行初始化,加载相应的参数,self代表实例本身
def __init__(self,root_dir,label_dir):
2)两个全局变量,第一个是相对路径,第二个是标