使用 PyTorch 的 torchvision 库加载 CIFAR-10 数据集

部署运行你感兴趣的模型镜像

CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
         每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。
 

import torchvision
import torchvision.transforms as transforms

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 自动下载训练集
trainset = torchvision.datasets.CIFAR10(
    root='./data',  # 数据保存路径
    train=True,
    download=True,  # 设置为True自动下载
    transform=transform
)

# 自动下载测试集
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

1. 导入必要的库

import torchvision

import torchvision.transforms as transforms

  1. torchvision:PyTorch 的视觉库,提供常用数据集、模型架构和图像转换工具。
  2. transforms:用于图像预处理的模块,如缩放、归一化等。

2. 定义数据预处理流程

transform = transforms.Compose([

    transforms.ToTensor(),

    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

])

  1. transforms.Compose:将多个预处理操作按顺序组合。
  2. transforms.ToTensor()
    1. 将 PIL 图像或 NumPy 数组(H×W×C,范围 0-255)转换为 PyTorch 张量(C×H×W,范围 0.0-0)。
  3. transforms.Normalize(mean, std)
    1. 对每个通道进行归一化:output = (input - mean) / std
    2. 这里mean=(0.5, 0.5, 0.5)std=(0.5, 0.5, 0.5)将像素值从[0.0, 1.0]映射到[-1.0, 1.0](例如,0.0→-1.0,1.0→1.0)。

3. 下载并加载训练集

trainset = torchvision.datasets.CIFAR10(

    root='./data',  # 数据保存路径

    train=True,     # True表示训练集(50,000张)

    download=True,  # 自动下载(如果数据不存在)

    transform=transform  # 应用预处理

)

  1. torchvision.datasets.CIFAR10:CIFAR-10 数据集类,包含 10 个类别(如飞机、汽车、鸟类等)的 60,000 张 32×32 彩色图像。
  2. 参数说明
    1. root='./data':数据将下载到当前目录的data文件夹中。
    2. train=True:加载训练集(50,000 张);若为False则加载测试集(10,000 张)。
    3. download=True:若数据不存在,自动从互联网下载(约 170MB)。
    4. transform=transform:对图像应用之前定义的预处理(转为张量并归一化)。

4. 下载并加载测试集

testset = torchvision.datasets.CIFAR10(

    root='./data',  # 与训练集路径一致

    train=False,    # 加载测试集

    download=True,  # 自动下载

    transform=transform  # 应用相同的预处理

)

  1. 测试集与训练集结构相同,但用于模型评估,不参与训练。

5. 数据验证与使用

下载完成后,数据将存储在./data/cifar-10-batches-py目录中。你可以:

  1. 查看数据集大小

print(len(trainset))  # 输出: 50000

print(len(testset))   # 输出: 10000

  1. 访问单个样本

image, label = trainset[0]  # 获取第一张图像及其标签

print(image.shape)  # 输出: (3, 32, 32)

print(label)        # 输出: 6(对应类别索引)

  1. 使用数据加载器批量处理数据

from torch.utils.data import DataLoader

trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

testloader = DataLoader(testset, batch_size=32, shuffle=False)

注意事项

  1. 下载路径
    1. 若指定路径(如./data)已存在 CIFAR-10 数据,download=True不会重复下载。
    2. 若路径错误或无写入权限,会抛出异常(如PermissionError)。
  2. 网络问题
    1. 首次下载需联网,可能需要几分钟。若下载中断,可删除./data目录后重新运行。
  3. 数据预处理
    1. 归一化参数meanstd通常根据数据集的统计特性设定。对于 CIFAR-10,常用(0.5, 0.5, 0.5)进行简单归一化。
    2. 若需要更精确的归一化,可计算数据集的真实均值和标准差(如mean=[0.4914, 0.4822, 0.4465]std=[0.2470, 0.2435, 0.2616])。

扩展应用

加载数据后,可用于训练 CNN 模型(如之前创建的SimpleCNN):

# 假设model已定义

from torch import nn, optim

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练循环

for epoch in range(5):  # 训练5个轮次

    for inputs, labels in trainloader:

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = criterion(outputs, labels)

        loss.backward()

        optimizer.step()

    print(f"Epoch {epoch+1} completed")

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值