文章目录
import torch,math
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as dsets
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as NN
torch.__version__
'1.2.0'
5.3 Fashion MNIST进行分类
Fashion MNIST 介绍
Fashion MNIST数据集 是kaggle上提供的一个图像分类入门级的数据集,其中包含10个类别的70000个灰度图像。如图所示,这些图片显示的是每件衣服的低分辨率(28×28像素)
数据集的下载和介绍:地址
Fashion MNIST的目标是作为经典MNIST数据的替换——通常被用作计算机视觉机器学习程序的“Hello, World”。
MNIST数据集包含手写数字(0-9等)的图像,格式与我们将在这里使用的衣服相同,MNIST只有手写的0-1数据的复杂度不高,所以他只能用来做“Hello, World”
而Fashion MNIST 的由于使用的是衣服的数据,比数字要复杂的多,并且图片的内容也会更加多样性,所以它是一个比常规MNIST稍微更具挑战性的问题。
Fashion MNIST这个数据集相对较小,用于验证算法是否按预期工作。它们是测试和调试代码的好起点。
数据集介绍
分类
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
格式
fashion-mnist_test.csv
fashion-mnist_train.csv
存储的训练的数据和测试的数据,格式如下:
label是分类的标签
pixel1-pixel784是每一个像素代表的值 因为是灰度图像,所以是一个0-255之间的数值。
为什么是784个像素? 28 * 28 = 784
数据提交
Fashion MNIST不需要我们进行数据的提交,数据集中已经帮助我们将 训练集和测试集分好了,我们只需要载入、训练、查看即可,所以Fashion MNIST 是一个非常好的入门级别的数据集
#指定数据目录
DATA_PATH=Path('./data/')
train = pd.read_csv(DATA_PATH / "fashion-mnist_train.csv");
train.head(10)
| label | pixel1 | pixel2 | pixel3 | pixel4 | pixel5 | pixel6 | pixel7 | pixel8 | pixel9 | ... | pixel775 | pixel776 | pixel777 | pixel778 | pixel779 | pixel780 | pixel781 | pixel782 | pixel783 | pixel784 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 6 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 5 | 0 | ... | 0 | 0 | 0 | 30 | 43 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 0 | 0 | 0 | 1 | 2 | 0 | 0 | 0 | 0 | ... | 3 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 4 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5 | 4 | 0 | 0 | 0 | 5 | 4 | 5 | 5 | 3 | 5 | ... | 7 | 8 | 7 | 4 | 3 | 7 | 5 | 0 | 0 | 0 |
| 6 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 14 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 8 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 2 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 9 | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 203 | 214 | 166 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10 rows × 785 columns
test = pd.read_csv(DATA_PATH / "fashion-mnist_test.csv");
test.head(10)
| label | pixel1 | pixel2 | pixel3 | pixel4 | pixel5 | pixel6 | pixel7 | pixel8 | pixel9 | ... | pixel775 | pixel776 | pixel777 | pixel778 | pixel779 | pixel780 | pixel781 | pixel782 | pixel783 | pixel784 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 9 | 8 | ... | 103 | 87 | 56 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 34 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 14 | 53 | 99 | ... | 0 | 0 | 0 | 0 | 63 | 53 | 31 | 0 | 0 | 0 |
| 3 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 137 | 126 | 140 | 0 | 133 | 224 | 222 | 56 | 0 | 0 |
| 4 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5 | 2 | 0 | 0 | 0 | 0 | 0 | 44 | 105 | 44 | 10 | ... | 105 | 64 | 30 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 6 | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7 | 6 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | ... | 174 | 136 | 155 | 31 | 0 | 1 | 0 | 0 | 0 | 0 |
| 8 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 57 | 70 | 28 | 0 | 2 | 0 | 0 | 0 | 0 | 0 |
10 rows × 785 columns
train.max()
label 9
pixel1 16
pixel2 36
pixel3 226
pixel4 164
...
pixel780 255
pixel781 255
pixel782 255
pixel783 255
pixel784 170
Length: 785, dtype: int64
ubyte文件标识了数据的格式
其中idx3的数字表示数据维度。也就是图像为3维,
idx1 标签维1维。
具体格式详解:http://yann.lecun.com/exdb/mnist/
import struct
from PIL import Image
with open(DATA_PATH / "train-images-idx3-ubyte", 'rb') as file_object:
header_data=struct.unpack(">4I",file_object.read(16))
print(header_data)
(2051, 60000, 28, 28)
with open(DATA_PATH / "train-labels-idx1-ubyte", 'rb') as file_object:
header_data=struct.unpack(">2I",file_object.read(8))
print(header_data)
(2049, 60000)
如下是训练的图片的二进制格式
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
有四字节的header_data,故使用unpack_from进行二进制转换时,偏置offset=16
with

本文通过构建CNN网络对FashionMNIST数据集进行分类任务,详细介绍了数据预处理、模型搭建、训练过程及评估方法。
最低0.47元/天 解锁文章
3615

被折叠的 条评论
为什么被折叠?



