1.导入相应的库:
关于Fashion MNIST数据集的介绍,可以看这位博主的文章:
https://blog.csdn.net/qq_28869927/article/details/85079808
import os
import cv2
import numpy as np
import tensorflow as tf
import keras
import functools
from keras.preprocessing.image import ImageDataGenerator
2.加载数据集:
# 2. Fetch the Fashion-MNIST train/test splits through tf.keras.datasets and
#    report the shape of every array (one line per array).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
print(*map(np.shape, (x_train, y_train, x_test, y_test)), sep='\n')
3.转换为独热编码:
# 3. Turn the integer class labels into one-hot rows of length 10 and bring
#    them back to NumPy so the rest of the pipeline works on plain arrays.
y_train, y_test = [tf.one_hot(labels, depth=10).numpy() for labels in (y_train, y_test)]
print(*map(tf.shape, (y_train, y_test)), sep='\n')
4.归一化处理并调整为(N,28,28,1)的形状:
# 4. Scale pixel values from [0, 255] to [0, 1] and add a trailing channel
#    axis so each image is (28, 28, 1) as the CNN's InputLayer expects.
#    float32 is used instead of 'float' (float64): it halves the memory of the
#    70k-image arrays and matches the precision the model computes in.
#    A NumPy reshape (not tf.reshape) keeps these as ndarrays, which is what
#    ImageDataGenerator.flow consumes further down.
x_train = (x_train.astype(np.float32) / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype(np.float32) / 255.0).reshape(-1, 28, 28, 1)
print(np.shape(x_train))
print(np.shape(y_train))
print(np.shape(x_test))
print(np.shape(y_test))
5.设置全局参数:
# 5. Global settings.
#    batch_Size      - samples per batch yielded by the data generators
#    num_classes     - width of the final Dense/softmax layer
#    EPOCHES         - number of training epochs passed to model.fit
#    num_predictions - not referenced in the code shown here; presumably used
#                      later in the article (TODO confirm)
batch_Size, num_classes, EPOCHES, num_predictions = 32, 10, 5, 20
6.搭建传统的CNN模型:
# 6. A conventional CNN: two identical conv stages (32 then 64 filters)
#    followed by a dense classification head with a softmax over num_classes.
def _conv_stage(filters):
    """Return the layers of one conv stage: conv-relu-conv-relu-pool-dropout."""
    return [
        tf.keras.layers.Conv2D(filters, kernel_size=[3, 3], strides=[1, 1], padding='same'),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(filters, kernel_size=[3, 3], strides=[1, 1]),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(pool_size=[2, 2]),
        tf.keras.layers.Dropout(0.25),
    ]


# Layers are created left-to-right in the same order as a flat list would be.
model = tf.keras.Sequential(
    [tf.keras.layers.InputLayer(input_shape=(28, 28, 1))]
    + _conv_stage(32)
    + _conv_stage(64)
    + [
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes),
        tf.keras.layers.Activation('softmax'),
    ]
)
model.summary()
7.数据增强:
# 7. Training-time augmentation: random shifts of up to 10% of the width and
#    height plus random horizontal flips. Every other ImageDataGenerator
#    option is left at its default (no centering/normalization, no ZCA, no
#    rotation/shear/zoom, 'nearest' fill, no rescale, no validation split).
data = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)
想了解更多关于ImageDataGenerator的内容,可以看这位博主的文章:
https://blog.csdn.net/qq_27825451/article/details/90056896
# Batches for training go through the augmenting generator defined above.
train_gen = data.flow(x_train, y_train, batch_size=batch_Size)
# Validation must see the real test images: a default ImageDataGenerator()
# applies no transformation, and shuffle=False keeps evaluation order fixed.
# (Reusing `data` here would randomly shift/flip the test set and make
# val_loss / val_accuracy noisy and not comparable to true test accuracy.)
test_gen = ImageDataGenerator().flow(
    x_test, y_test, batch_size=batch_Size, shuffle=False
)
8.模型优化器的选择和编译:
# 8. Adam with a small learning rate; categorical cross-entropy matches the
#    one-hot labels and softmax output, with accuracy tracked as a metric.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
9.训练:
# 9. Train from the augmenting generator, validating on test_gen each epoch,
#    then save the trained model to an HDF5 file.
#    batch_size is intentionally NOT passed: the generators already yield
#    batches of batch_Size, and Keras documents that batch_size must not be
#    specified for generator / Sequence inputs.
history = model.fit(
    train_gen,
    epochs=EPOCHES,
    verbose=1,
    validation_data=test_gen,
    workers=5,
)
model.save('fashion_Mnist_FM.h5')
训练结果:
10.查看history中的情况:
# 10. Dump the dict of per-epoch metrics recorded by model.fit.
metrics_by_epoch = history.history
print(metrics_by_epoch)
11.画出accuracy,val_accuracy,loss,val_loss: