读取CIFAR10数据集的数据并进行展示

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。

这里我们使用python读取CIFAR10其中的一部分数据并进行展示

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
torch.cuda.set_device(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 10
Epochs = 20


trans = torchvision.transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
cifar_10 = torchvision.datasets.CIFAR10(root='./data', train=True, transform=trans, download=False)
data_loader = DataLoader(cifar_10, batch_size=batch_size, shuffle=True)

def imshow(img):

    #反归一化,将数据重新映射到0-1之间
    img = img / 2 + 0.5
    plt.imshow(np.transpose(img.numpy(), (1,2,0)))
    plt.show()


for i, (images, _) in enumerate(data_loader):

    print(i)
    print(images.numpy().shape)
    imshow(images[0])
    break



  • ToTensor()可以将数据映射到0-1的范围内,方便进行计算,Normalize()通过设置均值和方差,可以将数据再次映射到-1到1之间,如原来的0有,(0 - 0.5) / 0.5 = -1 (1 - 0.5) / 0.5 = 1,得到最终映射后的结果
  • 但是在后面进行操作的时候,由于像素的值不可能是负值,所以我们还要进行反归一化操作,将数据重新映射到0-1范围内,具体方式是就是原来的数据除以2再加0.5 ,如:(-1/2) + 0.5 = 0 (1/ 2) + 0.5 = 1,可以将数据重新映射到0-1范围内
  • 经网络输出的图像一般的格式为(3,32,32),分别表示通道的数量,图像的宽度与高度,但是plt在显示的时候按照(32,32,3)的方式来进行读取,分别表示宽度,高度,通道的数量,所以使用np.transpose(1,2,0)来对每个维度的数据俩进行转换

显示的图像如图所示:
在这里插入图片描述
它的像素只是32*32的,像素值过低,所以显示的不是很清楚,我们可以将多张图片做成Grid的形式来展示,代码如下:

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
torch.cuda.set_device(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 50
Epochs = 20

trans = torchvision.transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
cifar_10 = torchvision.datasets.CIFAR10(root='./data', train=True, transform=trans, download=False)
data_loader = DataLoader(cifar_10, batch_size=batch_size, shuffle=True)

def imshow(img):

    #反归一化,将数据重新映射到0-1之间
    img = img / 2 + 0.5

    plt.imshow(np.transpose(img.numpy(), (1,2,0)))

    plt.show()


for i, (images, _) in enumerate(data_loader):

    print(i)
    print(images.numpy().shape)
    imshow(torchvision.utils.make_grid(images))
    break

得到的图像如下所示:
在这里插入图片描述

这些图像都是随机地从原来的数据集中提取的

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值