数据预处理阶段,dataset.py中主要定义了DogCat类,并定义了__init__、__getitem__、__len__三个类方法,话不多说直接上代码,具体解读都在每一行后面的注释中。
# coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
"""
#os.listdir() 方法 : 返回指定文件夹包含的文件或文件夹名字的列表。该列表顺序以字母排序。
#os.path.join() : 将多个路径组合后返回
# 即imgs中存储的是路径+文件名的列表,如下
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('\\')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
#划分训练、验证集,验证:训练 = 3:7
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
else:
self.imgs = imgs[int(0.7 * imgs_num):]
if transforms is None:
#对图片数据进行转换操作,测试、验证和训练的数据转换有所区别
#标准化,即减均值,除以标准差,1*3对应RGB三通道,这组数据是imagenet训练集中抽样算出来的
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
#测试集和验证集
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224), #缩放Image,长宽比不变,最短边为224像素
T.CenterCrop(224), #从图片中间切出224*224的图片
T.ToTensor(), #将图片转成Tensor并归一化至[0,1]
normalize #标准化
])
#训练集
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomResizedCrop(224), #先随机采集,然后对裁剪得到的图像缩放为同一大小
T.RandomHorizontalFlip(), #以给定的概率p随机水平旋转PIL的图像,p默认值为0.5
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
一次返回一张图片的数据
如果是测试集,没有图片id,如1000.jpg返回1000
"""
img_path = self.imgs[index] #第index张图片的路径+文件名
if self.test:
label = int(self.imgs[index].split('.')[-2].split('\\')[-1])
else:
label = 1 if 'dog' in img_path.split('\\')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data) #调用init中的transform处理图片
return data, label
def __len__(self):
'''
返回数据集中所有图片的个数
'''
return len(self.imgs)
附上我看代码时参考的别人的博客: