layout: post
title: 深度学习入门 基于Python的理论实现
subtitle: 第三章 手写数字识别
tags: [Machine learning, Reading]
第三章 神经网络
3.6 手写数字识别
上一个post介绍了神经网络的基本内容,这一节搭配项目解决实际问题。这个例子非常简单,是一个机器学习里的Hello world。手写数字识别问题。但是这个例子是不完全的,我们假设学习已经全部完成,我们用学习到的参数,先实现神经网络的“推理处理”。这也叫神经网络的前向传播。
3.6.1 MNIST数据集
这个数据集网上的资料实在太多了,就连他的进阶版本Fashion MNIST也出来很久了,相信能看到现在的人没有太多人不知道这个数据集。
介绍简单带过。MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。
MNIST的图像是28 × \times × 28像素的灰度图像(1通道),像素的取值在0到255之间。每个图像都标有对应的阿拉伯数字标签。
这本书提供了数据集和相应的代码。传送门
import sys,os
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
(x_train, t_train),(x_test, t_test) = load_mnist(flatten=True, normalize=False)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
第一次执行代码可能比较慢,原因是需要下载,服务器在国外,下载的比较慢。也可以手动下载放在文件夹,我就是用的这个方法(自动下载实在太慢了)。
于是我们打印出训练集,测试集和对应的label的shape。
这里对代码做一点简单说明,这里的load_mnist函数是将数据集做导入,分别为两个训练集两个测试集,flatten参数为True代表将28 × \times × 28的图像扁平化,变成1 × \times × 784的向量。normalize的含义是将数值标准化为0到1之间的数字,这个函数还可以传入一个参数,就是one_hot_label,这个参数设置为True将会让标签变为one hot representation。
因为这里并不涉及参数的训练,因此我们需要导入参数,这离有一个pkl文件,保存着训练好的参数,直接导入就可以。下来简单显示一下图片。
import sys,os,cv2
import numpy as np
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
(x_train, t_train),(x_test, t_test) = load_mnist(flatten=True, normalize=True)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
first_image = x_train[0]
first_label = t_train[0]