Pytorch学习笔记一:制作、加载自己的图像数据集
文章目录
前言
首先介绍如何用pytorch加载网络现有数据集,然后介绍如何制作自己的图像数据集并批量读取来训练自己的网络。
提示:以下是本篇文章正文内容,下面案例可供参考
一、下载数据集
使用Pytorch进行读取本地的MINIST数据集并进行装载
# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
root="./data", # 下载数据,并且存放在data文件夹中
train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
download=True
)
testDataset = torchvision.datasets.MNIST(
root="./data",
train=False,
transform=transforms.ToTensor(),
download=True
)
二、加载自己的数据集
1.制作数据集
训练神经网络需要标准输入图像和它的真值标签。
在分类问题中,比如猫、狗、船、车等等,我们可以用数字代表不同的分类。可以制作一个txt文档用于存放输入图像的地址和它对应的标签数字。
我现在有个任务需要以图像作为输入,以另一张处理过后的图像作为它的真值,所以我在txt文本下面写的是它们的路径。在项目路径下新建了一个train文件夹用于放训练图片,并在train文件夹下新建一个训练的txt用于标注训练图像和标签图像
2.加载数据集
Dataset类
PyTorch读取图片,主要是通过Dataset类是Pytorch中所有数据集加载类中应该继承的父类。我们通过继承改写Dataset类来读取自己的图像数据集。其中以下三个函数必须改写:
__init__方法里面进行读取数据文件
__getitem__方法进行支持下标访问
__len__方法返回自定义数据集的大小,方便后期遍历
class OpticalSARDataset(Data.Dataset):
&