PyTorch | torchvision.datasets/models/transforms


前言

torchvision是Pytorch的计算机视觉工具库,是Pytorch专门用于处理图像的库。主要由3个子包组成,分别是:torchvision.datasetstorchvision.modelstorchvision.transforms。即目前流行的数据集,模型结构和常用的图片转换工具。

一、torchvision.datasets

包含很多常用视觉数据集,可以下载和加载:
在这里插入图片描述
每一个数据集的API都是基本相同的。他们都有两个相同的参数:transform和target_transform。

最经典的MNIST数据集API为例:

import torchvision
mydataset = torchvision.datasets.MNIST(root='./',
                                      train=True,
                                      transform=None,
                                      target_transform=None,
                                      download=True)

包含5个参数:

root:想要保存MNIST数据集的位置,如果download是Flase的话,则会从目标位置读取数据集;
download:True的话就会自动从网上下载这个数据集,到root的位置;
train:True的话,数据集下载的是训练数据集;False的话则下载测试数据集
transform:对图像进行处理的transform,比方说旋转平移缩放,输入的是PIL格式的图像(不是tensor矩阵);
target_transform:这个是对图像标签进行处理的函数(这个我没用过不太确定)

torchvision.datasets.ImageFolder

参考:https://blog.youkuaiyun.com/qq_39507748/article/details/105394808

dataset=torchvision.datasets.ImageFolder(
                       root, transform=None, 
                       target_transform=None, 
                       loader=<function default_loader>, 
                       is_valid_file=None)
参数详解:
    root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
    transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
    target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
    loader:表示数据集加载方式,通常默认加载方式即可。
    is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件) 
返回的dataset都有以下三种属性:
    self.classes:用一个 list 保存类别名称
    self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
    self.imgs:保存(img-path, class) tuplelist                 

二、torchvision.models

torchvision提供了很多种预训练模型,大体分成四类:分别是分类模型,语义模型,目标检测模型和视频分类模型。
在这里插入图片描述
主要介绍分类的预训练模型:
在这里插入图片描述
构建模型可以通过下面的代码:

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
mobilenet = models.mobilenet_v2()

这样构建的模型的权重值是随机的,只有结构是保存的。想要获取预训练的模型,则需要设置参数pretrained:

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)

似乎这些模型的预训练数据集都是ImageNet,预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行normalization标准化。

例子

import torchvision.models as models

vgg16 = models.vgg16(pretrained = True) # 获取训练好的VGG16模型
pretrained_dict = vgg16.state_dict() # 返回包含模块所有状态的字典,包括参数和缓存

三、torchvision.transforms

transforms模块提供了一般的图像预处理方法, 例如

数据中心化
数据标准化
缩放
裁剪
旋转
翻转
填充
噪声添加
灰度变换
线性变换
仿射变换
亮度
饱满度及对比度变换
…

这些方法可以用于对图像的数据增强,又称为数据增广,是对训练集进行变换,使训练集更加丰富,从而使模型具有泛化能力。

函数作用
CenterCrop从图像中心裁剪图像
RandomCrop从图片中随即裁剪出给定尺寸的图片(可填充)
RandomHorziontalFlip依概率水平翻转图片
RandomVerticalFlip依概率垂直翻转图片
RandomRotation随机旋转图片
Resize修改图像分辨率
ColorJitter调整亮度,对比度,饱和度和色相
Totensor转化为张量
ToPILImage将ndarray或者张良转化为PIL Image类型数据
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值