pytorch学习笔记4-torchvision中数据集的使用

本文介绍了如何在PyTorch中使用torchvision模块加载和预处理CIFAR10数据集,包括下载、查看数据集结构、应用transform转换为tensor,并利用tensorboard进行可视化。
部署运行你感兴趣的模型镜像

pytorch官网的torchvision中可以看到提供的数据集以及详细的说明,现成的数据集使用起来也很方便。

以CIFAR10为例,点开可以看到CIFAR10的参数和返回值的相关介绍:

import torchvision

# 参数train为TRUE则返回训练集,为FALSE则返回测试集,download设置为TRUE则自动从网上下载
train_set = torchvision.datasets.CIFAR10(root="./P14_dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./P14_dataset", train=False, download=True)
'''等待下载数据集'''

可以用debug查看数据集的组成:

可看出有classes的属性等等。

print(test_set[0])    # 测试集第一个由什么组成
print(test_set.classes)    # 查看属性

img, target = test_set[0]    # 测试集每一个由图片和索引组成,提取第一个
print(img, "\n", target)
print(test_set.classes[target])    # 查看第一个的属性,为cat
img.show()    # 用电脑自带的应用查看图片

注意img.show()查看的是PIL类的图片

与transform结合,把图片转换为tensor类型并用tensorboard查看:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./P14_dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./P14_dataset", train=False, transform=dataset_transform, download=True)

# print(test_set[0])
# print(test_set.classes)

# img, target = test_set[0]
# print(img, "\n", target)
# print(test_set.classes[target])
# img.show()

writer = SummaryWriter("P14")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()

终端输入下面的命令打开tensorboard:

tensorboard --logdir=P14 --port=6008  

查看结果:

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值