前一段时间,训练2分类深度网络时,loss一直维持在2.3左右。在网上看了很多博客,最后从这篇博客中找到了,解决的方法。具体的不详细说了,可以参考那篇博客。我按照上面的提到问题,仔细检测自己的网络,发现可能是我的数据集中的图片和标签不一致所造成。
我的数据集是tfrecords格式的二进制文件(前提:确保制作数据集的图片不存在问题,不然读取时会报错。),我制作代码如下:
import glob
import tensorflow as tf
from PIL import Image
import numpy as np
import random
num=0
bestnum=5000
recordfilenum=0
filenames=[]
for filename in glob.glob('./data/PetImages/Cat/*.jpg')[2500:3000]:
tmp=[]
tmp.append(filename)
tmp.append(0)
filenames.append(tmp)
for filename in glob.glob('./data/PetImages/Dog/*.jpg')[2500:3000]:
tmp=[]
tmp.append(filename)
tmp.append(1)
filenames.append(tmp)
random.shuffle(filenames)
for filename in filenames:
if not num % bestnum: #超过1000,写入下一个tfrecord

训练深度网络时遇到loss恒定问题,发现是TFRecords数据集中图片和标签不一致导致。检查数据集制作过程,确保图片和标签一一对应。修正读取代码以同步加载图片和标签,解决了问题。
最低0.47元/天 解锁文章
1641

被折叠的 条评论
为什么被折叠?



