AlexNet网络诞生于2012年,是Hinton的代表作之一,也是当年ImageNet竞赛的冠军,Top5错误率为16.4%。AlexNet使用ReLU激活函数提升了训练速度,使用Dropout缓解了过拟合。
AlexNet一共8层:5个卷积层和3个全连接层。
import numpy as np
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import PySide2
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPooling2D,Dropout,Flatten,Dense
from tensorflow.keras import Model
# Publish the 'plugins/platforms' directory that lives inside the installed
# PySide2 package through the QT_QPA_PLATFORM_PLUGIN_PATH environment variable.
# NOTE(review): presumably required so the Qt-based matplotlib backend can find
# its platform plugin on this machine -- confirm before removing.
dirname = os.path.split(PySide2.__file__)[0]
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path
# Print arrays in full: never abbreviate their contents.
np.set_printoptions(threshold=np.inf)

# Load the Fashion-MNIST train/test splits and divide every pixel by 255.0
# (presumably mapping 8-bit intensities into [0, 1] -- TODO confirm dtype).
mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# The Fashion-MNIST images are 3-D, e.g. (60000, 28, 28), whereas the network
# below consumes 4-D input (as CIFAR-10 would provide). Append a trailing axis
# so the arrays become 4-D.
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)
class AlexNet8(Model):
    """AlexNet-style classifier with five convolutional and three dense layers.

    Layout (all convolutions use 3x3 kernels, all pools are 3x3 with stride 2):

    * conv(96, valid)  -> batch norm -> ReLU -> max pool
    * conv(256, valid) -> batch norm -> ReLU -> max pool
    * conv(384, same, ReLU) -> conv(384, same, ReLU) -> conv(256, same, ReLU)
      -> max pool
    * flatten -> dense(2048, ReLU) -> dropout(0.5)
      -> dense(2048, ReLU) -> dropout(0.5) -> dense(10, softmax)
    """

    def __init__(self):
        """Create every layer; layers are instantiated in forward-pass order."""
        super().__init__()

        # Convolutional stage 1.
        self.c1 = Conv2D(filters=96, kernel_size=(3, 3), padding='valid')
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.p1 = MaxPooling2D(pool_size=(3, 3), strides=2, padding='valid')

        # Convolutional stage 2.
        self.c2 = Conv2D(filters=256, kernel_size=(3, 3), padding='valid')
        self.b2 = BatchNormalization()
        self.a2 = Activation('relu')
        self.p2 = MaxPooling2D(pool_size=(3, 3), strides=2, padding='valid')

        # Convolutional stage 3: three stacked 'same' convolutions whose ReLU
        # is fused into the layer, followed by a single pooling layer.
        self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.c4 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.p3 = MaxPooling2D(pool_size=(3, 3), strides=2, padding='valid')

        # Classifier head.
        self.flatten = Flatten()
        self.f1 = Dense(2048, activation='relu')
        self.d1 = Dropout(0.5)
        self.f2 = Dense(2048, activation='relu')
        self.d2 = Dropout(0.5)
        self.f3 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Batch of input images (4-D, channels last, as prepared by the
                calling script).

        Returns:
            The softmax output of the final 10-unit dense layer.
        """
        # Feature extractor: two conv/BN/ReLU/pool stages, then the conv stack.
        out = self.p1(self.a1(self.b1(self.c1(x))))
        out = self.p2(self.a2(self.b2(self.c2(out))))
        out = self.p3(self.c5(self.c4(self.c3(out))))

        # Classifier head with dropout after each hidden dense layer.
        out = self.flatten(out)
        out = self.d1(self.f1(out))
        out = self.d2(self.f2(out))
        return self.f3(out)
# Instantiate the network and print the shape of its output on the test set as
# a sanity check. predict() feeds the images through in mini-batches; calling
# model(x_test) directly would push the entire test set through the network as
# one eager batch, which can exhaust memory. verbose=0 keeps the console output
# limited to the printed shape.
model = AlexNet8()
print(model.predict(x_test, batch_size=256, verbose=0).shape)

# from_logits=False because the network's last layer already applies softmax.
model.compile(
    optimizer='adam',
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# Resume from previously saved weights when a checkpoint is present; the
# existence of the '<path>.index' file is used as the signal.
checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------------------load the model---------------------')
    model.load_weights(checkpoint_save_path)

# Save weights only (not the full model), and only for the best epoch so far.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Train for 5 epochs, validating on the test split after every epoch. The
# returned history is kept for plotting the learning curves.
history = model.fit(
    x_train, y_train,
    batch_size=10,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
# Report the layer summary and the raw trainable variables on stdout.
model.summary()
print(model.trainable_variables)

# Write each trainable variable's name, shape and values to ./weights.txt.
# The context manager guarantees the file is closed even if a write fails.
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
############################ show #############################
# Plot the accuracy and loss curves of the training and validation sets.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# One entry per subplot: (position in the 1x2 grid, title, curves to draw).
panels = (
    (1, 'Training and Validation Accuracy',
     ((acc, 'Training Accuracy'), (val_acc, 'Validation Accuracy'))),
    (2, 'Training and Validation Loss',
     ((loss, 'Training Loss'), (val_loss, 'Validation Loss'))),
)
for position, title, curves in panels:
    plt.subplot(1, 2, position)
    for values, label in curves:
        plt.plot(values, label=label)
    plt.title(title)
    plt.legend()

plt.show()