PyTorch 中使用 DataLoader + TensorBoard 可视化 CIFAR-10 图像数据

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

 

文章目录


前言

本文介绍如何使用 torchvision.datasets.CIFAR10 和 DataLoader 加载图像数据,并使用 TensorBoard 进行批量图像可视化,适合于新手学习。


 

一、环境准备与导入库

import torchvision.transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

二、DataLoader + TensorBoard 可视化 CIFAR-10 图像数据

1.数据集加载

代码如下:

test_data = torchvision.datasets.CIFAR10(
    root="../database",         # 数据下载目录(尽量保存到当前文件夹)
    train=False,                # 训练集False则加载测试集
    transform=torchvision.transforms.ToTensor()  # 图像转换为 Tensor
)

 此处也可以使用Compose[]进行数据预处理,但是这里只是做一个简单的Totensor变换,如果想要对图像做更多的处理比如随机裁剪,扩大,缩小等等就需要使用Compose来处理了。

2.使用 DataLoader 批量加载数据

代码如下:

test_loader=DataLoader(
dataset=test_data,
batch_size=64,
shuffle=False,
num_workers=0,
drop_last=True
)
'''dataset=test_data,      # 测试数据集
    batch_size=4,           # 每个 batch 含 4 条数据
    shuffle=True,           # 打乱数据顺序(通常为False)
    num_workers=0,          # 使用 0 个子进程(主线程加载数据)
    drop_last=False         # 如果最后一组不足 4 条数据,也保留(不会丢弃)'''

这里还是建议各位同学移步https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

查看更多关于DataLoader,了解更多详细信息知识

 3.获取单张图像查看形状与标签

img, target = test_data[0]
print(img.shape)  # 输出: torch.Size([3, 32, 32])
print(target)     # 输出: 某张图像的类别索引(0~9)

输出结果如下:

 

3是RGB的红绿蓝三个通道(也就是彩色),32*32则指图像分辨率,可以知道画质依旧感人
3是标签

4.使用 TensorBoard 记录批量图像 

writer=SummaryWriter("../logs3")
for i in range(2):

    step=0#初始化步数
    for data in test_loader:
        img,target=data
        # print(img.shape)
        # print(target)
        # writer.add_image("test_data",img,step,dataformats="CHW")是错误的
        # writer.add_image("test_data",img,step,dataformats="NCHW")这个也对

        writer.add_images(f"echo_{i}",img,step)#默认dataformats是NCHW(这个完全对)


        step=step+


writer.close()

 #报错信息
'''torch.utils.tensorboard.writer.add_image 
方法默认期望接收格式为 CHW(通道  -  高度  -  宽度)的单张图像张量,
当传入形状为 NCHW(批量大小 - 通道 - 高度 - 宽度)的批量图像张量时,维度结构不配,
就会导致解析出错并抛出 AssertionError 报错。

解决办法:将add_image变为add_images'''

add_images() 支持批量图像的记录,要求输入为 [N, C, H, W] 格式,默认 dataformats='NCHW'。
若使用 add_image() 则只能写入一张图像,且维度为 [C, H, W]。
add_images() 可自动处理每张图片,适合用于 DataLoader 加载的批次。 

5.启动 TensorBoard 查看图像 

在终端中运行以下命令,打开 TensorBoard:

tensorboard --logdir=logs3 --samples_per_plugin=images=1000

 这里不用tensorboard --logdir=logs3默认只显示10张图像显示不了全部,使用这个命令后即可显示更多图像数据(最多1000张)。

6.可视化展示 


总结

学习了如何使用 torchvision.datasets.CIFAR10 和 DataLoader 加载图像数据,并使用 TensorBoard 进行批量图像可视化

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值