计算机视觉方向的训练如果看不到图片的话就不太好直观的判断出训练结果。
例如现在要训练一个CycleGAN网络,目标是输入一个杰哥,输出一个阿伟,那么下面这种展示效果就可以直观的判断出训练是否成功↓
代码如下↓
if __name__ == '__main__':
# 声明一个图框 ----------------------------------------------
fig = plt.figure()
# 往图框中插入一个子图, 1行2列的大小,放在第1个位置----
axl = fig.add_subplot(121)
axl.set_title('input img') # 设置子图标题
axl.axis('off') # 子图是否开启坐标 on/off
img = Image.open('input_img.jpg') # 打开一张图片
axl.imshow(img) # 让子图展示ta
# ----------------------------------------------
# 往图框中插入一个子图, 1行2列的大小,放在第2个位置----
axl = fig.add_subplot(122)
axl.set_title('output img') # 设置子图标题
axl.axis('off') # 子图是否开启坐标 on/off
img = Image.open('output_img.jpg') # 打开一张图片
axl.imshow(img) # 让子图展示ta
# ----------------------------------------------
# ----------------------------------------------------------
# 保存图框
# fig.savefig("result.png")
# plt展示绘制的全部内容
plt.show()
这样看还行,不过依然有很多问题:
1.代码比较多
2.训练过程中的图片不需要从路径中读取
3.如果我要每个epoch都打印出来看结果的话,那么展示结果的图片会有很多,名字不能重复
可以修改成这样
import matplotlib.pyplot as plt
from PIL import Image
def sample_img(input_img, output_img, epoch):
'''
每过一个epoch就让我康康结果
:param input_img: 输入图片
:param output_img: 输出图片
:param epoch: 当前训练轮数
:return:
'''
# 声明一个图框 ----------------------------------------------
fig = plt.figure()
# 往图框中插入一个子图, 1行2列的大小,放在第1个位置----
axl = fig.add_subplot(121)
axl.set_title('input img') # 设置子图标题
axl.axis('off') # 子图是否开启坐标 on/off
axl.imshow(input_img) # 让子图展示ta
# ----------------------------------------------
# 往图框中插入一个子图, 1行2列的大小,放在第2个位置----
axl = fig.add_subplot(122)
axl.set_title('output img') # 设置子图标题
axl.axis('off') # 子图是否开启坐标 on/off
axl.imshow(output_img) # 让子图展示ta
# ----------------------------------------------
# ----------------------------------------------------------
# 保存图框
fig.savefig("result_{0}.png".format(epoch))
if __name__ == '__main__':
# 随便整点模拟数据 -------------------------
input_img = Image.open('input_img.jpg')
output_img = Image.open('output_img.jpg')
epoch = 5
# ---------------------------------------
# 保存训练结果
sample_img(input_img, output_img, epoch)
这样每经过一个epoch调用一下这个函数就可以把每轮训练结果都保存到当前目录了
如果你觉得一张一张查看训练结果比较麻烦,你还可以试试下面这段代码,可以让matplotlib动态展示所有训练结果,就像播放动画一样。
import matplotlib.pyplot as plt
from PIL import Image
def get_all_file_in_dir(dir_path):
'''
获取目录下的所有文件
:return:
'''
file_name_list = []
for root, dirs, files in os.walk(dir_path):
if files:
for name in files:
file_name = '{0}/{1}'.format(root, name).replace('\\', '/')
file_name_list.append(file_name)
return file_name_list
def show_train_results(input_img_list, output_img_list):
'''
动态展示训练结果
:param input_img_list: 一组输入图片
:param output_img_list: 对应的一组输出图片
:return:
'''
# 声明一个图框 ----------------------------------------------
fig = plt.figure()
# 往图框中插入一个子图, 1行2列的大小,放在第1个位置----
axl1 = fig.add_subplot(121)
axl1.set_title('input img') # 设置子图标题
axl1.axis('off') # 子图是否开启坐标 on/off
# ----------------------------------------------
# 往图框中插入一个子图, 1行2列的大小,放在第2个位置----
axl2 = fig.add_subplot(122)
axl2.set_title('output img') # 设置子图标题
axl2.axis('off') # 子图是否开启坐标 on/off
# ----------------------------------------------
# ----------------------------------------------------------
for img_tup in list(zip(input_img_list, output_img_list)):
axl1.imshow(Image.open(img_tup[0])) # 让子图展示ta
axl2.imshow(Image.open(img_tup[1])) # 让子图展示ta
plt.pause(0.01)
plt.show()
if __name__ == '__main__':
# 随便整点模拟数据 -------------------------------------------------------------------------------------
img_list = get_all_file_in_dir(r'D:\机器学习数据集\celebA(分卷形式,一起解压)\celebA\img\img_align_celeba')
input_img_list = img_list[0: 500]
output_img_list = img_list[500: 1000]
# ---------------------------------------------------------------------------------------------------
# 动态展示
show_train_results(input_img_list, output_img_list)