数据集下载
MINST_PNG_Training在github的项目目录中的datasets中有MNIST的png格式数据集的压缩包
用于训练的神经网络模型

自定义数据集训练
在前文【Pytorch】13.搭建完整的CIFAR10模型我们已经知道了基本搭建神经网络的框架了,但是其中的数据集使用的torchvision中的CIFAR10官方数据集进行训练的
train_dataset = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,
transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,
transform=torchvision.transforms.ToTensor())

本文将用图片格式的数据集进行训练

我们通过
# Dataset CIFAR10
# Number of datapoints: 60000
# Root location: ../datasets
# Split: Train
# StandardTransform
# Transform: ToTensor()
print(train_dataset)
可以看到我们下载的数据集是这种格式的,所以我们的主要问题就是如何将自定义的数据集获取,并且转化为这种形式,剩下的步骤就和上文相同了
数据类型进行转化
我们的首要目的是,根据数据集的地址,分别将数据转化为train_dataset与test_dataset
我们需要调用ImageFolder方法来进行操作
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *
# 训练集地址
train_root = "../datasets/mnist_png/training"
# 测试集地址
test_root = '../datasets/mnist_png/testing'
# 进行数据的处理,定义数据转换
data_transform = transforms.Compose([transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor()])
# 加载数据集
train_dataset = ImageFolder(train_root

最低0.47元/天 解锁文章
4645

被折叠的 条评论
为什么被折叠?



