keras手写数字识别
手写数字识别,可以说是机器学习领域的“hello world”。
对于初学者来说,这可能是一个很好的案例。程序是在jupyter上面跑的,所以有很多中间结果,参考书目是《Python深度学习》。
import keras
from keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
Using TensorFlow backend.
#获取mnist数据集(如果本地没有,会连网下载)
#训练集image格式(60000,28,28)
#训练集label格式(60000,)
#测试集image格式(10000,28,28)
#测试集label格式(10000,)
(train_img,train_label),(test_img,test_label)=mnist.load_data()
#拷贝一份测试数据以备后续预测
test_predict=test_img.copy()
#分离训练集和验证集
val_img=train_img[:30000]
train_img=train_img[30000:]
val_label=train_label[:30000]
train_label=train_label[30000:]
#将图片展平+归一化(0~1)
train_img=train_img.reshape((-1,28*28)).astype('float64')
train_img=train_img/255
#验证集
val_img=val_img