Training a VGG-style model on the CIFAR-10 dataset
Results
Training set: (results screenshot omitted)
Test set: (results screenshot omitted)
The accuracy on the test set feels fairly low; I suspect the rest could be squeezed out through hyperparameter tuning.
But I don't really want to tune here; I just want to compare how different models behave on the same dataset.
How the loss and accuracy change during training
When accessing the training history, inexplicable-looking errors sometimes come up,
for example:
KeyError: 'accuracy'
KeyError: 'val_acc' and KeyError: 'acc'
show_train_history(history, 'loss', 'accuracy')
The offending keys are the 'loss' and 'accuracy' strings in the call above. The mismatch is caused by differing Keras versions, which name the history keys differently, so check which keys your own version actually records:
print("H.histroy keys:", history.history.keys())
One more thing: when I fed the test set in for evaluation, I once hit this error:
Failed to find data adapter that can handle input: <class 'numpy.ndarray'>,
<class 'numpy.ndarray'> contain <class 'int'> (roughly this; I wrote the second half from memory)
It comes from a type mismatch: load_data returns the test labels as a plain Python list of ints, which fit/evaluate cannot consume directly. Adding the type conversions below solved it:
test_image = np.array(test_image)
test_label = np.array(test_label)
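A quick way to confirm what evaluate is being handed (a hypothetical check, not in the original script): after the conversion the labels should be an integer ndarray, which sparse_categorical_crossentropy accepts.

print(type(test_label), test_label.dtype, test_label.shape)
# e.g. <class 'numpy.ndarray'> with an int32/int64 dtype (platform-dependent) and shape (10000,)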
Data-loading code
import pickle
import numpy as np

def load_cifar10_batch(cifar10_dataset_folder_path, batch_id):
    with open(cifar10_dataset_folder_path + '/data_batch_' + str(batch_id), mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')
    # features and labels
    features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    labels = batch['labels']  # note: this is a plain Python list, not an ndarray
    return features, labels

def load_data():
    # load all the data
    cifar10_path = 'D:/Code/Python/data/cifar-10-batches-py'
    # the training data comes in 5 batches
    x_train, y_train = load_cifar10_batch(cifar10_path, 1)
    for i in range(2, 6):  # batches 2-5 (range(2, 5) would silently skip data_batch_5)
        features, labels = load_cifar10_batch(cifar10_path, i)
        x_train, y_train = np.concatenate([x_train, features]), np.concatenate([y_train, labels])
    # load the test data
    with open(cifar10_path + '/test_batch', mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')
    x_test = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    y_test = batch['labels']  # again a Python list; hence the np.array fix above
    return x_train, y_train, x_test, y_test
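As a sanity check (my own snippet, assuming the folder path above), the shapes should come out as the full CIFAR-10 split once all five training batches are read:

x_train, y_train, x_test, y_test = load_data()
print(x_train.shape, x_test.shape)  # expected (50000, 32, 32, 3) and (10000, 32, 32, 3)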
Main code
import matplotlib.pyplot as plt
import numpy as np
import os
import keras
import cifar_load_data

#***********************************load the data************************************#
train_image, train_label, test_image, test_label = cifar_load_data.load_data()
# normalize pixel values to [0, 1]
train_image = train_image.astype(np.float32) / 255.0
test_image = test_image.astype(np.float32) / 255.0
def build_model():
    vgg_model = keras.models.Sequential()
    # block 1: two 3x3 convs with 32 filters, then pool
    vgg_model.add(keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu", input_shape=(32, 32, 3)))
    vgg_model.add(keras.layers.BatchNormalization(momentum=0.9))
    vgg_model.add(keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu"))
    vgg_model.add(keras.layers.BatchNormalization(momentum=0.9))
    vgg_model.add(keras.layers.MaxPool2D(pool_size=(2, 2), padding="same"))
    vgg_model.add(keras.layers.Dropout(rate=0.25))
    # block 2: two 3x3 convs with 64 filters, then pool
    vgg_model.add(keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu"))
    vgg_model.add(keras.layers.BatchNormalization(momentum=0.9))
    vgg_model.add(keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu"))
    vgg_model.add(keras.layers.BatchNormalization(momentum=0.9))
    vgg_model.add(keras.layers.MaxPool2D(pool_size=(2, 2), padding="same"))
    vgg_model.add(keras.layers.Dropout(rate=0.25))
    # classifier head
    vgg_model.add(keras.layers.Flatten())
    vgg_model.add(keras.layers.Dense(512, activation="relu"))
    vgg_model.add(keras.layers.BatchNormalization(momentum=0.9))
    vgg_model.add(keras.layers.Dropout(rate=0.25))
    vgg_model.add(keras.layers.Dense(10, activation="softmax"))
    return vgg_model
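Before training it can be worth eyeballing the layer stack (my own snippet): after the two pooling stages the 32x32 input is down to 8x8x64 = 4096 features entering the Dense layers.

model = build_model()
model.summary()  # prints output shapes and parameter counts per layer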
# load a previously trained model if one exists, otherwise train from scratch
history = None
if os.path.exists('D:/Code/Python/module/cifar_module.h5'):
    module = keras.models.load_model('D:/Code/Python/module/cifar_module.h5')
else:
    module = build_model()
    module.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    history = module.fit(train_image, train_label, epochs=50)
    module.save('D:/Code/Python/module/cifar_module.h5')

# convert the labels to ndarrays (see the data adapter error above)
test_image = np.array(test_image)
test_label = np.array(test_label)
result_loss, result_accuracy = module.evaluate(test_image, test_label, verbose=2)
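To spot-check individual predictions (a hypothetical snippet; the model outputs softmax probabilities over the 10 classes), argmax gives the predicted class id:

probs = module.predict(test_image[:5])
print(np.argmax(probs, axis=1))  # predicted class ids
print(test_label[:5])            # ground-truth class ids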
def show_train_history(train_history, train, validation):
    plt.plot(train_history.history[train])       # first curve, e.g. training loss
    plt.plot(train_history.history[validation])  # second curve, e.g. training accuracy
    plt.ylabel(train)
    plt.xlabel('epoch')
    plt.legend([train, validation], loc='upper left')
    plt.show()

print("Test accuracy:", result_accuracy)
if history is not None:  # only available when the model was trained in this run
    show_train_history(history, 'loss', 'accuracy')
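Finally, tying back to the val_acc/val_accuracy KeyErrors above: those keys only exist if you hold out validation data during fit. A sketch (my variation, assuming a Keras version that uses the 'accuracy' spelling):

history = module.fit(train_image, train_label, epochs=50, validation_split=0.2)
show_train_history(history, 'accuracy', 'val_accuracy')  # train vs. validation accuracy per epoch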