第一步:读入数据
# 导入必要的库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import numpy as np
import matplotlib.pyplot as plt
# 获取数据集
(x, x_lable), (y, y_lable) = datasets.cifar10.load_data()
对数据进行归一化处理可以使得数据运算速度加快,同时减少异常数据带来的影响。本次数据集为图片,分布范围为0~255,仅需要将每一个数值除以255即可将数据集归纳到0~1之间。
x = x/255
y = y/255
第二步:设置回调函数
在模型训练过程中,我们无法对模型进行相关性的操作,此时就需要使用到tensflow中的回调函数了。
我们可以指定一个很大的epoch(训练轮数),当验证集的损失值在一定次数内都没有降低时,代表模型已经运行到了最优值附近,当前的学习率已经无法使得梯度继续下降了,此时就通过回调函数终止模型的训练。
# 提前结束训练
earlyStop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=25)
第一个参数代表监控模型训练过程中的参数
第二个参数代表能够容忍模型监控值没有下降的次数。
在模型的训练过程中,有时会出现一个良好的模型,此时我们可以通过回调函数保存这个模型。
# 设置模型保存节点
checkpoint_save_path = '.\\tmp\\model_4.h5'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
monitor='val_loss', mode='min',
save_best_only=True)
第一个参数代表保存模型的路径
第二个参数代表监模型的值,
第三个参数代表