PyTorch读取图片数据

本文介绍使用PyTorch加载MNIST数据集的方法,包括数据下载、转换、加载及可视化过程。通过具体代码实例展示如何批量获取数据并显示部分样本。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CNN实战基础——读取图片数据
实现结果如下图:
在这里插入图片描述

取数据

1.导包

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

2.下载数据

trans = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=False, transform=trans, download=True)

root='xx' :表示 数据集所在路径,数据集不用解压
train=True or False:表明是 训练集或测试集
transform=trans :把读取的Img类型图片转为Tensor
down=True :表示 若数据集在root路径下,则直接加载;若不在root里,则下载(外网很慢,可以提前下载好放进去)

下图为下载之后,自动创建了raw文件夹,数据集在raw里
在这里插入图片描述

3.加载数据

train_dl = DataLoader(train_dataset, batch_size=32)

param1:指明加载什么数据集
param2:一批有32张图

4.打印数据
我们往往想看看第一批(32张)图,但不能通过train_dl[0]等下标方式访问,下面有两种方式都可以:

4.1 next + enumerate

examples = enumerate(train_dl) # 返回数字下标+迭代器(可用next访问)
next(examples)

4.2 next + iter

examples = iter(train_dl) # 返回迭代器,无下标
imgs, labels = next(examples)

next()\iter()函数

结果可视化

第一版

import matplotlib.pyplot as plt
fig = plt.figure() # 创建一个窗口
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.imshow(imgs[i][0])
  plt.title('label:{}'.format(labels[i]))# 格式化输出,否则是tensor数据

plt.imshow(imgs[i][0])这里指:imgs是4维数据,后面2维是28*28图片大小,不用管,关键是用前两维定位到图片
结果:
在这里插入图片描述
发现布局有问题,故增加plt.tight_layout()语句,结果明显分开

在这里插入图片描述
发现x,y轴的数字可以不要,添加plt.xticks([]), plt.yticks([]),结果:
在这里插入图片描述

想要灰度图,设置一个参数即可plt.imshow(imgs[i][0],cmap='gray')
在这里插入图片描述

完整代码:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

trans = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(
    root='./drive/MyDrive/DataSet', train=False, transform=trans, download=True)

train_dl = DataLoader(train_dataset, batch_size=32)

# train_dl里有很多batch,每个batch里有batch_size张图片
#examples = enumerate(train_dl)
#next(examples)
# 用iter会少序号,和enumerate有一点不同,其它差不多可以打印出来看
examples = iter(train_dl)
imgs, labels = next(examples)
# print(len(imgs), len(labels))

#-------------------------------数据显示--------------------------------------------
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(imgs[i][0],cmap='gray') # 4维数据,后面2维是28*28图片大小,不用管,关键是用前两维定位到图片
  plt.title('label:{}'.format(labels[i]))# 格式化输出,否则是tensor数据
  plt.xticks([])            # 清空x,y轴的数字
  plt.yticks([])

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值