一、DataLoader和DataSet
1.1 DataLoader
torch.utils.data.DataLoader #构建可迭代的数据装载器
DataLoader()参数 :
- dataset: Dataset类,决定数据从哪读取 及如何读取
- batchsize : 批大小
- num_works: 是否多进程读取数据
- shuffle: 每个epoch是否乱序
- drop_last:当样本数不能被batchsize整 除时,是否舍弃最后一批数据
- Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
- Iteration:一批样本输入到模型中,称之为一个Iteration
- Batchsize:批大小,决定一个Epoch有多少个Iteration
- 例如:样本总数:80, Batchsize:8 1 Epoch = 10 Iteration
1.2 DataSet
torch.utils.data.Dataset
Dataset抽象类,所有自定义的 Dataset需要继承它,并且复写__getitem__()
getitem :接收一个索引,返回一个样本
例如:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
#初始化。一般写些该类的全局变量,为后边函数提供变量
def __init__(self, data_dir, label_dir):
#不同的数据格式,此处的处理方法不同
self.data_dir = data_dir
self.label_dir = label_dir
self.path = os.path.join(self.data_dir, self.label_dir)
self.img_path = os.listdir(self.path) #获取路径中的每一张图片
def __getitem__(self, item):
img_name = self.img_path[item]
img_item_path = os.path.join(self.data_dir, self.label_dir, img_name) #获取图片路径
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
二、transforms
2.1 transforms
torchvision:计算机视觉工具包
torchvision.transforms : 常用的图像预处理方法——数据中心化 • 数据标准化 • 缩放 • 裁剪 • 旋转 • 翻转 • 填充 • 噪声添加 • 灰度变换 • 线性变换 • 仿射变换 • 亮度、饱和度及对比度变换
torchvision.datasets : 常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等
torchvision.model : 常用的模型预训练,AlexNet,VGG, ResNet,GoogLeNet等
transforms.Normalize(mean, std, inplace=False) #逐channel的对图像进行标准化为(-1,1)
- output = (input - mean) / std
- mean:各通道的均值
- std:各通道的标准差
- inplace:是否原地操作