matplotlib训练可视化

本文介绍了如何利用matplotlib简化CycleGAN训练结果的展示,包括在每个epoch后保存图片以及通过动态展示来观察训练过程。代码示例中展示了如何创建子图并保存每个epoch的训练结果,同时提供了动态展示所有训练结果的函数,使得观察模型训练进度更加直观便捷。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

计算机视觉方向的训练如果看不到图片的话就不太好直观的判断出训练结果。
例如现在要训练一个CycleGAN网络,目标是输入一个杰哥,输出一个阿伟,那么下面这种展示效果就可以直观的判断出训练是否成功↓

在这里插入图片描述
代码如下↓

if __name__ == '__main__':

    # 声明一个图框 ----------------------------------------------

    fig = plt.figure()

    # 往图框中插入一个子图, 1行2列的大小,放在第1个位置----
    axl = fig.add_subplot(121)
    axl.set_title('input img')              # 设置子图标题
    axl.axis('off')                         # 子图是否开启坐标 on/off
    img = Image.open('input_img.jpg')       # 打开一张图片
    axl.imshow(img)                         # 让子图展示ta
    # ----------------------------------------------

    # 往图框中插入一个子图, 1行2列的大小,放在第2个位置----
    axl = fig.add_subplot(122)
    axl.set_title('output img')             # 设置子图标题
    axl.axis('off')                         # 子图是否开启坐标 on/off
    img = Image.open('output_img.jpg')      # 打开一张图片
    axl.imshow(img)                         # 让子图展示ta
    # ----------------------------------------------

    # ----------------------------------------------------------

    # 保存图框
    # fig.savefig("result.png")

    # plt展示绘制的全部内容
    plt.show()


这样看还行,不过依然有很多问题:
1.代码比较多
2.训练过程中的图片不需要从路径中读取
3.如果我要每个epoch都打印出来看结果的话,那么展示结果的图片会有很多,名字不能重复

可以修改成这样

import matplotlib.pyplot as plt
from PIL import Image


def sample_img(input_img, output_img, epoch):
    '''
    每过一个epoch就让我康康结果
    :param input_img: 输入图片
    :param output_img: 输出图片
    :param epoch: 当前训练轮数
    :return:
    '''
    # 声明一个图框 ----------------------------------------------

    fig = plt.figure()

    # 往图框中插入一个子图, 1行2列的大小,放在第1个位置----
    axl = fig.add_subplot(121)
    axl.set_title('input img')      # 设置子图标题
    axl.axis('off')                 # 子图是否开启坐标 on/off
    axl.imshow(input_img)           # 让子图展示ta
    # ----------------------------------------------

    # 往图框中插入一个子图, 1行2列的大小,放在第2个位置----
    axl = fig.add_subplot(122)
    axl.set_title('output img')     # 设置子图标题
    axl.axis('off')                 # 子图是否开启坐标 on/off
    axl.imshow(output_img)          # 让子图展示ta
    # ----------------------------------------------

    # ----------------------------------------------------------

    # 保存图框
    fig.savefig("result_{0}.png".format(epoch))


if __name__ == '__main__':

    # 随便整点模拟数据 -------------------------
    input_img = Image.open('input_img.jpg')
    output_img = Image.open('output_img.jpg')
    epoch = 5
    # ---------------------------------------

    # 保存训练结果
    sample_img(input_img, output_img, epoch)
    

这样每经过一个epoch调用一下这个函数就可以把每轮训练结果都保存到当前目录了


如果你觉得一张一张查看训练结果比较麻烦,你还可以试试下面这段代码,可以让matplotlib动态展示所有训练结果,就像播放动画一样。

在这里插入图片描述

import matplotlib.pyplot as plt
from PIL import Image


def get_all_file_in_dir(dir_path):
    '''
    获取目录下的所有文件
    :return:
    '''
    file_name_list = []
    for root, dirs, files in os.walk(dir_path):
        if files:
            for name in files:
                file_name = '{0}/{1}'.format(root, name).replace('\\', '/')
                file_name_list.append(file_name)
    return file_name_list


def show_train_results(input_img_list, output_img_list):
    '''
    动态展示训练结果
    :param input_img_list: 一组输入图片
    :param output_img_list: 对应的一组输出图片
    :return:
    '''
    # 声明一个图框 ----------------------------------------------
    fig = plt.figure()

    # 往图框中插入一个子图, 1行2列的大小,放在第1个位置----
    axl1 = fig.add_subplot(121)
    axl1.set_title('input img')  # 设置子图标题
    axl1.axis('off')  # 子图是否开启坐标 on/off
    # ----------------------------------------------

    # 往图框中插入一个子图, 1行2列的大小,放在第2个位置----
    axl2 = fig.add_subplot(122)
    axl2.set_title('output img')  # 设置子图标题
    axl2.axis('off')  # 子图是否开启坐标 on/off
    # ----------------------------------------------

    # ----------------------------------------------------------

    for img_tup in list(zip(input_img_list, output_img_list)):
        axl1.imshow(Image.open(img_tup[0]))  # 让子图展示ta
        axl2.imshow(Image.open(img_tup[1]))  # 让子图展示ta
        plt.pause(0.01)

    plt.show()


if __name__ == '__main__':

    # 随便整点模拟数据 -------------------------------------------------------------------------------------
    img_list = get_all_file_in_dir(r'D:\机器学习数据集\celebA(分卷形式,一起解压)\celebA\img\img_align_celeba')
    input_img_list = img_list[0: 500]
    output_img_list = img_list[500: 1000]
    # ---------------------------------------------------------------------------------------------------

    # 动态展示
    show_train_results(input_img_list, output_img_list)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

什么都干的派森

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值