1. Dataset management overview
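PyTorch uses Dataset to manage training and test datasets. The torchvision.datasets.MNIST dataset covered in the previous post, like the other datasets in torchvision.datasets, inherits from Dataset. A Dataset is consumed through a DataLoader: wrap the Dataset object in a DataLoader, then during training use a for loop to pull batches of data out of the DataLoader and train the model on them.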
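The Dataset-in-DataLoader flow described above can be sketched roughly as follows. To stay self-contained and avoid downloading MNIST, this sketch wraps random tensors in TensorDataset (which is also a Dataset subclass). The tensor shapes, the batch size, and the tiny linear model, loss, and optimizer are all assumptions made for illustration; none of them come from the original post.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 100 fake 1x28x28 "images" with labels in 0..9.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)  # TensorDataset subclasses Dataset

# Put the Dataset into a DataLoader; it yields shuffled batches of up to 16 samples.
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Illustrative model, loss, and optimizer (assumptions, not from the post).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch_shapes = []
for batch_images, batch_labels in loader:  # one iteration per batch
    batch_shapes.append(tuple(batch_images.shape))
    optimizer.zero_grad()
    loss = loss_fn(model(batch_images), batch_labels)
    loss.backward()
    optimizer.step()

# 100 samples split into batches of 16 gives ceil(100 / 16) = 7 batches;
# the last batch holds the remaining 4 samples.
print(len(dataset), len(loader))
print(batch_shapes[0], batch_shapes[-1])
```

len(loader) counts batches rather than samples, and with the default drop_last=False the final batch is simply smaller than batch_size.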
2. The Dataset class
We can subclass Dataset to manage our own training and test data. An example of subclassing Dataset:
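import os

import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


# Custom dataset: subclass torch.utils.data.Dataset
class CDataSet(Dataset):
    def __init__(self, path):
        self.path = path
        self.list = os.listdir(path)  # file names of the images in the folder
        self.len = len(self.list)
        self.name = ['cloudy', 'rain', 'shine', 'sunrise']  # class names
        self.trans = transforms.ToTensor()  # HWC uint8 image -> CHW float tensor in [0, 1]

    def __len__(self):
        # Number of samples in the dataset
        return self.len

    def __getitem__(self, item):
        imgpath = os.path.join(self.path, self.list[item])
        print(imgpath)
        img = cv2.imread(imgpath)  # returns None if the file cannot be read
        if img is None:
            raise FileNotFoundError(f'cannot read image: {imgpath}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
        img = cv2.resize(img, (100, 100))
        img = self.trans(img)
        # Derive the label from the class name found in the path.
        # Labels are 1-based here (cloudy=1 ... sunrise=4); use i instead of i + 1
        # if the loss expects 0-based class indices (e.g. nn.CrossEntropyLoss).
        label = None
        for i, n in enumerate(self.name):
            if n in imgpath:
                label = i + 1
                break
        if label is None:
            raise ValueError(f'no known class name in path: {imgpath}')
        return img, label


ds = CDataSet(r'E:\test\pythonProject\dataset\cloudy')
dl = DataLoader(ds, batch_size=16, shuffle=True)
print(len(ds))  # number of samples
print(len(dl))  # number of batches
print(type(ds))
print(type(dl))
print(next(iter(dl)))  # first batch: [images, labels]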
'''
D:\anaconda3\python.exe E:\test\pythonProject\test.py
300
19
<class '__main__.CDataSet'>
<class 'torch.utils.data.dataloader.DataLoader'>
E:\test\pythonProject\dataset\cloudy\cloudy294.jpg
E:\test\pythonProject\dataset\cloudy\cloudy156.jpg
E:\test\pythonProject\dataset\cloudy\cloudy149.jpg
E:\test\pythonProject\dataset\cloudy\cloudy148.jpg
E:\test\pythonProject\dataset\cloudy\cloudy3.jpg
E:\test\pythonProject\dataset\cloudy\cloudy106.jpg
E:\test\pythonProject\dataset\cloudy\cloudy137.jpg
E:\test\pythonProject\dataset\cloudy\cloudy276.jpg
E:\test\pythonProject\dataset\cloudy\cloudy147.jpg
E:\test\pythonProject\dataset\cloudy\cloudy8.jpg
E:\test\pythonProject\dataset\cloudy\cloudy164.jpg
E:\test\pythonProject\dataset\cloudy\cloudy293.jpg
E:\test\pythonProject\dataset\cloudy\cloudy116.jpg
E:\test\pythonProject\dataset\cloudy\cloudy56.jpg
E:\test\pythonProject\dataset\cloudy\cloudy187.jpg
E:\test\pythonProject\dataset\cloudy\cloudy177.jpg
[tensor([[[[0.2235, 0.2471, 0.3569, ..., 0.1490, 0.1373, 0.1373],
[0.2902, 0.4039, 0.4078, ..., 0.1529, 0.1373, 0.1294],
[0.3294, 0.4941, 0.4000, ..., 0.1529, 0.1333, 0.1137],
...,
[0.0118, 0.0118, 0.0118, ..., 0.0078, 0.0078, 0.0078],
[0.0118, 0.0118, 0.011