导入模块
from keras.datasets import mnist
from scipy import signal
import matplotlib.pyplot as plt
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
import tensorflow as tf
加载数据集
# Load the MNIST training and test splits; each split is an (images, labels) pair.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
将每张输入图像展平为一维向量
# Flatten every image into a 1-D float32 vector so it can feed a Dense layer.
# num_pixels was used here (and when building the dense model) without ever
# being defined, which raised NameError; derive it from the loaded data.
# Assumes X_train is a 3-D array of (samples, rows, cols), as returned by
# mnist.load_data().
num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
# Bare expression: only displays the shape when evaluated in an interactive cell.
X_train.shape
X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')
# Bare expression: only displays the shape when evaluated in an interactive cell.
X_test.shape
归一化输入像素
# Rescale pixel intensities from the 0-255 range down to the 0-1 range.
X_train, X_test = X_train / 255.0, X_test / 255.0
将标签用one-hot形式编码
# One-hot encode the digit labels of both splits: each label is replaced by a
# 1-D indicator vector.
y_train, y_test = map(np_utils.to_categorical, (y_train, y_test))
建立全连接层神经网络
# Fully connected baseline: one hidden ReLU layer as wide as the flattened
# input, followed by a softmax layer with one unit per digit class.
num_classes = 10
model = Sequential([
    Dense(num_pixels, input_dim=num_pixels, activation='relu'),
    Dense(num_classes, activation='softmax'),
])
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
训练模型,查看模型情况
# Train the dense model for 10 epochs in batches of 128, using the test split
# as validation data, then print the layer/parameter summary.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    epochs=10,
    batch_size=128,
    verbose=True,
)
model.summary()
使用卷积神经网络模型
将输入数据调整为图片形式(28×28×1)
# Convolutional network input: reshape every sample to a 28x28 image with a
# single channel axis, stored as float32.
X_train = X_train.reshape((len(X_train), 28, 28, 1)).astype('float32')
X_test = X_test.reshape((len(X_test), 28, 28, 1)).astype('float32')
卷积神经网络模型
# CNN: two stages of (3x3 conv, 3x3 conv, 2x2 max-pool) with 32 and then 64
# filters, followed by a 128-unit dense layer and a softmax classifier.
model = Sequential([
    Conv2D(filters=32, kernel_size=3, input_shape=(28, 28, 1), activation='relu'),
    Conv2D(filters=32, kernel_size=3, activation='relu'),
    MaxPooling2D(pool_size=2),
    Conv2D(filters=64, kernel_size=3, activation='relu'),
    Conv2D(filters=64, kernel_size=3, activation='relu'),
    MaxPooling2D(pool_size=2),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(num_classes, activation='softmax'),
])
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
训练模型,查看正确率和模型情况
# Train the CNN for 10 epochs in batches of 128, using the test split as
# validation data.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    epochs=10,
    batch_size=128,
    verbose=True,
)
# score[1] is the accuracy metric requested in compile(); report the
# complementary error rate as a percentage.
score = model.evaluate(X_test, y_test, verbose=0)
print("Large CNN Error: %.2f%%" % (100 - score[1] * 100))
model.summary()