获取数据集
数据集简介
本节中将使用数据集Fashion-MNIST,Fashion-MNIST 中⼀共包括了 10 个类别,分别为:t-shirt(T 恤)、trouser(裤⼦)、pullover(套衫)、dress(连⾐裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和 ankle boot(短靴)。在该数据集中含有6000个训练集样本、1000个测试集样本。
为什么不用手写数据集做实验了呢?
原因为:
1、MNIST is too easy. Convolutional nets can achieve 99.7% on MNIST,太简单
2、MNIST is overused,过度使用
3、MNIST can not represent modern CV tasks, 过时
代码讲解
导入包
import gluonbook as gb
from mxnet.gluon import data as gdata
import sys
import time
获取数据集与测试集
#获取数据集与测试集
#MXNet的gdata.vision提供了FashionMNIST的数据集
mnist_train = gdata.vision.FashionMNIST(train = True)
mnist_test = gdata.vision.FashionMNIST(train = False)
将数值标签转化为文本标签
因为数据集含有10个类别,在数据中的标签为0-9,所以在图像显示时,我们想更加方便的知道图像对应的的类别名字,而不是冰冷的数字,那么就可以采用一下代码,将数字标签转化为与其对应的文本标签。
#将数值标签转化为文本标签
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
显示图像的函数
在使用代码的时候,难免想观察一下图像效果,通过该函数,可以显示出图像与其对应的标签。
def show_fashion_mnist(images, labels):
gb.use_svg_display()
_, figs = gb.plt.su