# Dataset只是去告诉我们程序,我们的数据集在什么位置,数据集第一个数据给它一个索引0,它对应的是哪一个数据。
# Dataloader就是把数据加载到神经网络当中,Dataloader所做的事就是每次从Dataset中取数据,至于怎么取,是由Dataloader中的参数决定的。
import torchvision.transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备测试用的数据集
test_data = torchvision.datasets.CIFAR10(root='D:\PyCharm\CIFAR10', train=False,
transform=torchvision.transforms.ToTensor())
img, target = test_data[0]
print(img.shape)
print(img)
# batch_size=4 使得 img0, target0 = dataset[0]、img1, target1 = dataset[1]、img2, target2 = dataset[2]、img3, target3 = dataset[3],然后这四个数据作为Dataloader的一个返回
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# 用for循环取出DataLoader打包好的四个数据
for data in test_loader:
imgs, targets = data # 每个data都是由4张图片组成,imgs.size 为 [4,3,32,32],四张32×32图片三通道,targets由四个标签组成
print(imgs)
print(targets)
# Tensorboard展示
test_data = torchvision.datasets.CIFAR10(root='D:\PyCharm\CIFAR10', train=False,
transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
writer = SummaryWriter("logs7")
step = 0
for data in test_loader:
imgs, targets = data
# add_image函数期望的输入格式是(H, W, C)
# 即(高度, 宽度, 通道数),但是您提供的数据格式是(C, H, W)
# 即(通道数, 高度, 宽度)。
# add_image函数期望的输入是一个单一的图像,而不是一个批次的数据。
# imgs是一个形状为(64, 3, 32, 32)的张量,这意味着它是一个批次,包含了64张3通道的32x32大小的图像。add_image函数无法直接处理这样的批次数据。
for i, img in enumerate(imgs):
# 图像通常以 [批次大小, 通道数, 高度, 宽度] 的顺序存储,即 (N, C, H, W)。
# 将图像的维度顺序从 (C, H, W) 改变为 (H, W, C)
img_hwc = img.permute(1, 2, 0)
# 为每一张图像添加一个标签,例如 'test_data/0', 'test_data/1', ...
writer.add_image(f'test_data/{i}', img_hwc, step, dataformats='HWC')
step = step + 1
writer.close()
#Dataloader多轮次
test_data=torchvision.datasets.CIFAR10(root='D:\PyCharm\CIFAR10',train=False,transform=torchvision.transforms.ToTensor())
test_loader=DataLoader(dataset=test_data, batch_size=64,shuffle=True,num_workers=0,drop_last=True)
#drop_last=True 这个参数用于指定在批次划分数据时,如果数据总数不能被批次大小整除,是否丢弃最后一个不完整的批次。
writer=SummaryWriter("logs8")
for epoch in range(2):
step=0
for data in test_loader:
imgs,targets=data
for i,img in enumerate(imgs):
img_hwc=img.permute(1, 2, 0)
writer.add_image(f"Epoch/{epoch},test_data{i}", img_hwc, step,dataformats="HWC")
step = step + 1
writer.close()
dataloader使用
最新推荐文章于 2024-10-29 19:26:56 发布