pytorch学习笔记(五)——综合实践
以一个完整的多分类项目为例,构建简单的CNN网络,对FashionMNIST数据集分类
导入库
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
得到数据集
# 包含 10 类时尚物品的灰度图像数据集,每张图像的大小为 28x28 像素
train_data = datasets.FashionMNIST(
root = "data", # 表示数据集将存储在当前目录下的 data 文件夹中
train = True, # 指定是否加载训练集,如果为 False,则加载测试集
download = True, # 如果数据集不存在,则自动下载
transform = ToTensor(), # 图像数据转换为Tensor
target_transform = None # 表示不对标签数据进行任何转换
)
test_data = datasets.FashionMNIST(
root = 'data',
train = False,
download = True,
transform = ToTensor()
)
可视化
image, label = train_data[0] # 查看第一张图片
plt.imshow(image.squeeze())
plt.title(label)
class_names = train_data.classes
# 从数据集中随机选取16张图片展示
for i in range(1,17):
# 随机生成图片序号
random_idx = torch.randint(0, len(train_data), size = [1]).item()
img, label = train_data[random_idx]
# 在fig中添加一个4*4的子图,用于显示当前图像
fig.add_subplot(4, 4, i)
plt.imshow