文章目录
前言
本文介绍如何使用 torchvision.datasets.CIFAR10 和 DataLoader 加载图像数据,并使用 TensorBoard 进行批量图像可视化,适合于新手学习。
一、环境准备与导入库
import torchvision.transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
二、DataLoader + TensorBoard 可视化 CIFAR-10 图像数据
1.数据集加载
代码如下:
test_data = torchvision.datasets.CIFAR10(
root="../database", # 数据下载目录(尽量保存到当前文件夹)
train=False, # 训练集False则加载测试集
transform=torchvision.transforms.ToTensor() # 图像转换为 Tensor
)
此处也可以使用Compose[]进行数据预处理,但是这里只是做一个简单的Totensor变换,如果想要对图像做更多的处理比如随机裁剪,扩大,缩小等等就需要使用Compose来处理了。
2.使用 DataLoader 批量加载数据
代码如下:
test_loader=DataLoader(
dataset=test_data,
batch_size=64,
shuffle=False,
num_workers=0,
drop_last=True
)
'''dataset=test_data, # 测试数据集
batch_size=4, # 每个 batch 含 4 条数据
shuffle=True, # 打乱数据顺序(通常为False)
num_workers=0, # 使用 0 个子进程(主线程加载数据)
drop_last=False # 如果最后一组不足 4 条数据,也保留(不会丢弃)'''
这里还是建议各位同学移步https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
查看更多关于DataLoader,了解更多详细信息知识
3.获取单张图像查看形状与标签
img, target = test_data[0]
print(img.shape) # 输出: torch.Size([3, 32, 32])
print(target) # 输出: 某张图像的类别索引(0~9)
输出结果如下:
![]()
3是RGB的红绿蓝三个通道(也就是彩色),32*32则指图像分辨率,可以知道画质依旧感人
3是标签
4.使用 TensorBoard 记录批量图像
writer=SummaryWriter("../logs3")
for i in range(2):
step=0#初始化步数
for data in test_loader:
img,target=data
# print(img.shape)
# print(target)
# writer.add_image("test_data",img,step,dataformats="CHW")是错误的
# writer.add_image("test_data",img,step,dataformats="NCHW")这个也对
writer.add_images(f"echo_{i}",img,step)#默认dataformats是NCHW(这个完全对)
step=step+
writer.close()
#报错信息
'''torch.utils.tensorboard.writer.add_image
方法默认期望接收格式为 CHW(通道 - 高度 - 宽度)的单张图像张量,
当传入形状为 NCHW(批量大小 - 通道 - 高度 - 宽度)的批量图像张量时,维度结构不配,
就会导致解析出错并抛出 AssertionError 报错。解决办法:将add_image变为add_images'''
add_images() 支持批量图像的记录,要求输入为 [N, C, H, W] 格式,默认 dataformats='NCHW'。
若使用 add_image() 则只能写入一张图像,且维度为 [C, H, W]。
add_images() 可自动处理每张图片,适合用于 DataLoader 加载的批次。
5.启动 TensorBoard 查看图像
在终端中运行以下命令,打开 TensorBoard:
tensorboard --logdir=logs3 --samples_per_plugin=images=1000
这里不用tensorboard --logdir=logs3默认只显示10张图像显示不了全部,使用这个命令后即可显示更多图像数据(最多1000张)。
6.可视化展示



总结
学习了如何使用 torchvision.datasets.CIFAR10 和 DataLoader 加载图像数据,并使用 TensorBoard 进行批量图像可视化
2421

被折叠的 条评论
为什么被折叠?



