PyTorch加载数据初认识
1. Dataset
与Dataloader
1.1. Dataset
与Dataloader
理解
1.2. 三种数据集标签标注形式:
- 文件夹名字就是标签名
- 有另外一个文件记录下文件与标签的映射关系;
- 文件本身带有标签名、
1.3. 在Jupyter中查询使用方法:
help(Dataset)
Dataset??
2. 加载数据集
2.1. 数据集准备
-
将数据引入项目中:
2.2. input
与label
理解
2.3. python console
分布操作
由于需要使用PIL
库,执行命令安装PIL
:pip install pillow
说明:PIL(Python Imaging Library)
是一个用于图像处理的经典库
2.3.1 打开图片
python console
中执行以下命令打开图片:
from PIL import Image
img_path = "image path" # image path为图片地址,可以使用相对路径或绝对路径
img = Image.open(img_path) # 将图片地址对应的图片赋值给img变量
img.size # 查看图片的长宽
(768, 512)
img.show() # 打开图片
说明:
-
加载图片路径需要用双斜线
-
右侧显示图片的属性
2.3.2. 创建图片列表
由于加载图片首先需要获取图片的地址img_path
,可以将所有图片地址转换成list
列表,然后可以通过idx
索引到图片的地址
使用os.listdir()
创建图片地址列表:
import os
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir_path) # 把路径下的文件装入到list容器中,创建图片地址列表
img_path_list[0] # 索引到第一张图片
'0013035.jpg' # 输出结果
2.3.3. 拼接路径
使用os.path.join()
拼接路径:
import os
root_dir = "dataset/train" # 数据根目录
label_dir = "ants" # 图片文件夹为图片标签
path = os.path.join(root_dir, label_dir) # 将root_dir与label_dir两个路径拼接
说明:linux
与windows
不同系统下/
有着不同含义,使用os.path.join
可以根据不同系统自动拼接路径,不会出错
2.3.4. 通过索引idx
获取img
和label
import os
root_dir = "dataset/train"
label_dir = "ants"
path = os.path.join(root_dir, label_dir)
img_path = os.listdir(path) # 图片文件夹路径
idx = 0
img_name = img_path[idx] # 索引到图片名称
img_item_path = os.path.join(root_dir, label_dir, img_name) # 获取图片相对路径
from PIL import Image
img = Image.open(img_item_path)
label = label_dir
3. 完整代码展示
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 图片root地址
self.label_dir = label_dir # 图片label地址
self.path = os.path.join(self.root_dir, self.label_dir) # 拼接root与label地址,获取图片文件夹路径
self.img_path = os.listdir(self.path) # 获取所有图片的列表
def __getitem__(self, idx):
img_name = self.img_path[idx] # 获取图片名称
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 获取图片相对路径
img = Image.open(img_item_path) # 获取索引图片
label = self.label_dir # 获取图片label
return img, label # 需要返回img和label
def __len__(self):
return len(self.img_path) # 返回图片数据长度
root_dir = "dataset/train" # 数据集root目录
ants_label_dir = "ants" # 获取ants数据集label
bees_label_dir = "bees" # 获取bees数据集label
ants_dataset = MyData(root_dir, ants_label_dir) # 获取ants数据集实例
bees_dataset = MyData(root_dir, bees_label_dir) # 获取bees数据集实例
train_dataset = ants_dataset + bees_dataset # ants与bees数据集合并
img_ants, label_ants = ants_dataset[0] # 索引获取ants图片数据
print(img_ants, label_ants) # 输出获取内容
img_ants.show() # 打开图片
img_bees, label_bees = bees_dataset[0] # 索引获取bees图片数据
print(img_bees, label_bees) # 输出获取内容
img_bees.show() # 打开图片
print(len(train_dataset)) # 查看合并数据集长度
print(len(ants_dataset)) # 查看ants数据集长度
print(bees(ants_dataset)) # 查看bees数据集长度
说明:一个函数当中的局部变量不能传递给另一个函数使用,__init__
函数中变量使用self
可以将函数中指定的变量传递给后面的函数使用,相当于指定一个类中的全局变量
4. 其他常用数据集表达形式
将数据集分成image
文件夹和label
文件夹
image
文件夹存放数据集相应的图片
label
文件夹存放数据集对应的标签