计算机视觉2.17.2:监控深度学习的训练过程

上篇文章1.17.1 注意欠拟合和过拟合中,我们一起讨论了过拟合和欠拟合发生的现象以及导致的原因。

在本文中,我们将创建一个TrainingMonitor回调函数,在训练时的每个epoch结束时调用。监视器会将训练和验证的loss和准确率序列化到硬盘当中保存,并且以此为数据进行绘图。

通过这个回调函数,我们就可以照看我们的训练过程,及时发现过拟合,阻止实验的继续进行,以防计算资源和时间的白白浪费。

创建训练监视器

目录结构如下:

|----callbacks
|		|----__init__.py
|		|----trainingmonitor.py
|----nn
|----utils

打开training monitor.py文件,写入如下代码:

from tensorflow.keras.callbacks import BaseLogger
import matplotlib.pyplot as plt
import numpy as np
import json
import os


class TrainingMonitor(BaseLogger):
    def __init__(self, figPath, jsonPath=None, startAt=0):
        """
        :param figPath:     存放可视化loss和acc图像的地址
        :param jsonPath:    可选的路径,用于将loss和acc序列化成json文件
        :param startAt:     当用ctrl+c停止训练时的恢复的起始epoch
        """
        super(TrainingMonitor, self).__init__()
        self.figPath = figPath
        self.jsonPath = jsonPath
        self.startAt = startAt

    def on_train_begin(self, logs={}):
        # 定义H为losees的历史数值
        self.H = {}

        # 检查JSON路径是否提供,如果提供检查是否存在
        # 如果存在,那么读取其中内容并更新至历史字典H中
        if self.jsonPath is not None:
            if os.path.exists(self.jsonPath):
                self.H = json.loads(open(self.jsonPath).read())
                if self.startAt > 0:
                    for k in self.H.keys():
                        self.H[k] = self.H[k][:self.startAt]

    def on_epoch_end(self, epoch, logs={}):
        # 循环取日志中的值,更新loss和acc等
        for (k, v) in logs.items():
            l = self.H.get(k, [])
            l.append(v)
            self.H[k] = l

        # 如果提供了json路径,那么向其中写入H数据
        if self.jsonPath is not None:
            f = open(self.jsonPath, "w")
            f.write(json.dumps(self.H))
            f.close()

        # 确保绘图前至少已经过去了两个epochs
        if len(self.H["loss"]) > 1:
            N = np.arange(0, len(self.H["loss"]))
            plt.style.use("ggplot")
            plt.figure()
            plt.plot(N, self.H["loss"], label="train_loss")
            plt.plot(N, self.H["val_loss"], label="val_loss")
            plt.plot(N, self.H["accuracy"], label="train_acc")
            plt.plot(N, self.H["val_accuracy"], label="val_acc")
            plt.title("Training Loss and Accuracy [Epoch {}]".format(len(self.H["loss"])))
            plt.xlabel("Epoch#")
            plt.ylabel("lOSS/Accuracy")
            plt.legend()

            plt.savefig(self.figPath)
            plt.close()

监控训练

创建cifar10_monitor.py并写入如下代码:

import matplotlib
matplotlib.use("Agg")

from callback.trainingmonitor import TrainingMonitor
from sklearn.preprocessing import LabelBinarizer
from nn.conv.minivggnet import MiniVGGNet
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import cifar10
import os

# 输出数据和图像的路径
output = "/Users/lingg/PycharmProjects/DLstudy/monitor_result"
# 用操作系统分配的进程ID 来为绘图和JSON文件命名,
# 如果看到训练效果不好,可以直接用任务管理器杀掉进程
print("[INFO] process ID:{}".format(os.getpid()))

print("[INFO] loading CIFAR-10 data...")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0

lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

labelNames = ["airplane", "automobile", "bird", "cat", "deer", "dog",
              "frog", "horse", "ship", "truck"]

print("[INFO] compiling model...")
opt = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
model = MiniVGGNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

figPath = os.path.sep.join([output, "{}.png".format(os.getpid())])
jsonPath = os.path.sep.join([output, "{}.json".format(os.getpid())])
callbacks = [TrainingMonitor(figPath, jsonPath)]

print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=64, epochs=100, callbacks=callbacks, verbose=1)

查看/Users/lingg/PycharmProjects/DLstudy/monitor_result路径,会发现生成了两个以进程id命名的文件:xxx.json和xxx.png,分别是训练和验证的loss、准确率数据,以及所绘成的图像。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值