Torchvison-dataset的使用
这里介绍的时是Torchvision中关于数据库Dataset的一些使用方法。
首先我们可以在Pytorch观望中看到Torchvision中的很多数据集:
以CIFAR为例,点进去后可以了解到更多关于该数据集的一些信息:
在这里介绍了调用该数据库时的一些参数的设置及其功能。
Dataset的使用:
首先,我们需要导入torchvision库,为后面调用数据库提供库,同时导入SummaryWriter库,使用tersorboard可视化过程:
import torchvision
from torch.utils.tensorboard import SummaryWriter
我们定义train_set用于调用CIFAR10数据集:
torchvision.datasets.CIFAR10()
train_set = torchvision.datasets.CIFAR10(root='./dataset_learn',train=True,transform=dataset_transform,download=True)
①root=’'为该数据集保存的目录位置
②train为True代表该数据集用于训练,否则用于测试集
③transform为用于使用的将PIL图像数据转化为tensor类型的函数操作,可以自己定义操作内容(可以参照官方文档:transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop)
④download为True代表如果该数据集不存在的话,自动下载;数据集存在的话不下载
用一个for循环将数据集中的前10个样本展示在tersorboard上:
writer = SummaryWriter("p10")
for i in range(10):
img,target=train_set[i]
writer.add_image('dataset',img,i)
tensorboard上展示结果为:
完整实现代码: