文章目录
前言
torchvision是Pytorch的计算机视觉工具库,是Pytorch专门用于处理图像的库。主要由3个子包组成,分别是:torchvision.datasets
、torchvision.models
、torchvision.transforms
。即目前流行的数据集,模型结构和常用的图片转换工具。
一、torchvision.datasets
包含很多常用视觉数据集,可以下载和加载:
每一个数据集的API都是基本相同的。他们都有两个相同的参数:transform和target_transform。
最经典的MNIST数据集API为例:
import torchvision
mydataset = torchvision.datasets.MNIST(root='./',
train=True,
transform=None,
target_transform=None,
download=True)
包含5个参数:
root:想要保存MNIST数据集的位置,如果download是Flase的话,则会从目标位置读取数据集;
download:True的话就会自动从网上下载这个数据集,到root的位置;
train:True的话,数据集下载的是训练数据集;False的话则下载测试数据集
transform:对图像进行处理的transform,比方说旋转平移缩放,输入的是PIL格式的图像(不是tensor矩阵);
target_transform:这个是对图像标签进行处理的函数(这个我没用过不太确定)
torchvision.datasets.ImageFolder
参考:https://blog.youkuaiyun.com/qq_39507748/article/details/105394808
dataset=torchvision.datasets.ImageFolder(
root, transform=None,
target_transform=None,
loader=<function default_loader>,
is_valid_file=None)
参数详解:
root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
loader:表示数据集加载方式,通常默认加载方式即可。
is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
返回的dataset都有以下三种属性:
self.classes:用一个 list 保存类别名称
self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
self.imgs:保存(img-path, class) tuple的 list
二、torchvision.models
torchvision提供了很多种预训练模型,大体分成四类:分别是分类模型,语义模型,目标检测模型和视频分类模型。
主要介绍分类的预训练模型:
构建模型可以通过下面的代码:
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
mobilenet = models.mobilenet_v2()
这样构建的模型的权重值是随机的,只有结构是保存的。想要获取预训练的模型,则需要设置参数pretrained:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
似乎这些模型的预训练数据集都是ImageNet,预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行normalization标准化。
例子
import torchvision.models as models
vgg16 = models.vgg16(pretrained = True) # 获取训练好的VGG16模型
pretrained_dict = vgg16.state_dict() # 返回包含模块所有状态的字典,包括参数和缓存
三、torchvision.transforms
transforms模块提供了一般的图像预处理方法, 例如
数据中心化
数据标准化
缩放
裁剪
旋转
翻转
填充
噪声添加
灰度变换
线性变换
仿射变换
亮度
饱满度及对比度变换
…
这些方法可以用于对图像的数据增强,又称为数据增广,是对训练集进行变换,使训练集更加丰富,从而使模型具有泛化能力。
函数 | 作用 |
---|---|
CenterCrop | 从图像中心裁剪图像 |
RandomCrop | 从图片中随即裁剪出给定尺寸的图片(可填充) |
RandomHorziontalFlip | 依概率水平翻转图片 |
RandomVerticalFlip | 依概率垂直翻转图片 |
RandomRotation | 随机旋转图片 |
Resize | 修改图像分辨率 |
ColorJitter | 调整亮度,对比度,饱和度和色相 |
Totensor | 转化为张量 |
ToPILImage | 将ndarray或者张良转化为PIL Image类型数据 |
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}