数据集准备
FashionMNIST数据集作为一个比较入门的图像分类数据集,数量较少,个人认为是一个很好的入门数据集。其中包含了10个不同类别的服装和鞋类商品的灰度图像。每个类别包含6000张训练图像和1000张测试图像,共计70000张图像。
这些图像的分辨率是28x28像素,它们相对较小,适合用于快速原型设计和实验。FashionMNIST数据集通常用于测试深度学习模型在图像分类任务上的性能,并作为MNIST数据集的一种替代选择。
FashionMNIST数据集中的10个类别分别是:
T-shirt/top (T恤/上衣)
Trouser (裤子)
Pullover (套头衫)
Dress (裙子)
Coat (外套)
Sandal (凉鞋)
Shirt (衬衫)
Sneaker (运动鞋)
Bag (包)
Ankle boot (短靴)
数据集导入:
train_dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)
可以选择在线下载,网络不好的好提前离线下载好。
分类实现
首先导入工具包:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
from time import time
然后进行数据准备:
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root=

文章介绍了如何使用FashionMNIST数据集进行入门级图像分类任务,包括数据集导入、构建简单CNN模型、在PyTorch中进行训练和评估,最终得到91.59%的准确率。
最低0.47元/天 解锁文章
4261

被折叠的 条评论
为什么被折叠?



