一、使用MNIST数据集
本次学习使用神经网络识别手写数字,我们使用的数据集是MNIST数据集,MNIST数据集的长相如下图所示。
MNIST数据集是由0 到9 的数字图像构成。训练图像有6 万张,测试图像有1 万张,这些图像可以用于学习和推理。MNIST数据集的一般使用方法是,先用训练图像进行学习,再用学习到的模型度量能在多大程度上对测试图像进行正确的分类。 MNIST的图像数据是28 像素 × 28 像素的灰度图像(1 通道),各个像素的取值在0 到255 之间。每个图像数据都相应地标有“7”、“2”、“1”等标签。
load_mnist函数以“( 训练图像, 训练标签),( 测试图像,测试标签)”的形式返回读入的MNIST数据。
def load_mnist():
train_labels_path = 'train-labels.idx1-ubyte'
test_labels_path = 't10k-labels.idx1-ubyte'
train_images_path = 'train-images.idx3-ubyte'
test_images_path = 't10k-images.idx3-ubyte'
with open(train_labels_path, 'rb') as lpath:
magic, n = struct.unpack('>II', lpath.read(8))