1 数据集的组织形式
train训练集 val验证集
1.1第一种组织形式
文件夹名就是label,里面放的就是相应label的图片
1.2第二种组织形式
图片和label分开文件夹存放,相应文件名要一致,txt文件里存放相应图片的label
1.3第三种组织形式
label直接为图片的名称
2 pytorch读取数据涉及两个类:Dataset & Dataloader
Dataset:提供一种方式,获取需要的数据和对应的label 值,并完成编号。主要实现两个功能:
- 获取每一个数据及其对应label
- 统计数据集中的数据数量(神经网络经常需要对一个数据迭代多次,只有知道当前有多少数据,进行训练时才知道要训练多少次,才能把整个数据集迭代完)
Dataloader:打包(batch_size),为后面的神经网络提供不同的数据形式
2.1 Dataset类
数据的第一种组织形式
1、查看官方文档:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
该类是一个抽象类,所有的数据集想要在数据与标签之间建立映射都需要继承这个类。所有的子类都需要重写__getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label。子类也可以重写__len__方法,返回数据集的size大小。
2、导入Dataset类
from torch.utils.data import Dataset
3、创建一个类(需要继承Dataset类),类名为MyData
class MyData(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
pass
def __len__(self):
pass
在实现MyData类之前,首先需要读取图像数据
4、读取图像数据
- 方法一:使用PIL
- 方法二:使用OpenCV
使用PIL:
from PIL import Image
img_path = "dataset/train/ants/0013035.jpg"
img = Image.open(img_path)
img.show()
可以在Python控制台看到读取出来的 img,是一个JpegImageFile类的对象,展开还可以查看这个对象的属性,比如size。
使用OpenCV:
import cv2
img_path = "dataset/train/ants/0013035.jpg"
cv_img = cv2.imread(img_path)
cv2.imshow('cv_img',cv_img)
cv2.waitKey(0) ##等待时间,毫秒级,0表示任意键终止
cv2.destroyAllWindows()
可以在Python控制台看到读取出来的cv_img,是一个ndarray类的对象。
5、获取图片的文件名
借助os模块,从数据集路径中,获取所有文件的文件名,存储在一个列表中。
import os
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir_path)
可以使用索引去访问列表中的每个文件名
6、继续完善MyData
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir #根路径 如“dataset/train”
self.label_dir = label_dir #子路径 如“ants”
self.path = os.path.join(self.root_dir,self.label_dir) #将路径拼接起来
self.img_path = os.listdir(self.path) #该路径下的所有图片名变成一个列表
def __getitem__(self, index):
img_name = self.img_path[index] #根据索引从列表中获取了图片名
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) #组装路径,获得了图片的具体路径
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
- __init__():接收实例化时传入的参数(根路径和子路径),然后将两个路径进行组合,就得到了目标数据集的路径。将这个路径作为参数传入listdir函数,从而让img_path中存储该目录下所有文件名(通过索引就可以轻松获取每个文件名)。
- __getitem__():index是索引编号,用于对初始化中得到的文件名列表进行索引访问(这样就得到了具体的文件名),然后与根路径和子路径再次组合,得到具体数据(图片)的相对路径。然后用PIL读取具体的图像。返回的img是JpegImageFile对象,label是字符串。
- __len__():获取数据集的长度,在初始化中已经获取了所有文件名的列表,只需知道这个列表的长度。
7、使用MyData
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
8、完整代码
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, index):
img_name = self.img_path[index]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
8、补充
- 获取数据集中指定的数据
ants_dataset[0]
img,label = ants_dataset[0]
img.show()
- 拼接数据集
用处:有时候数据集不足就需要仿造数据集,需要将仿造数据集和真实数据集结合后训练
train_dataset = ants_dataset + bees_dataset
变成数据集的第二种组织形式:当label比较复杂,存储数据比较多时,是以每张图片对应一个txt文件,txt里存储label信息的方式。
为数据集中的每张图片生成对应的txt文件:
import os root_dir = 'dataset-练习/train' target_dir = 'ants_image' img_path = os.listdir(os.path.join(root_dir, target_dir)) label = target_dir.split('_')[0] out_dir = 'ants_label' for i in img_path: file_name = i.split('.jpg')[0] with open(os.path.join(root_dir, out_dir,"{}.txt".format(file_name)),'w') as f: f.write(label)