在tensorflow上使用FCN训练自己的数据
参考文档:https://blog.youkuaiyun.com/m0_37407756/article/details/83379026 tensorflow实现FCN完成训练自己标注的数据
准备数据
- 在github上下载fcn的tensorflow版本实现
- 下载labelme对自己的图片进行标注,把生成的json文件放入你的json_file中
- 使用下面的json2png.py批量将json文件转化为可以进行训练的png格式图片,此时生成的文件夹下每个json文件夹对应img.png,lable.png,label_viz.png三张图片。img.png作为输入,label.png作为标注图像(显示为全黑,实际像素值很小)。程序中将生成的png图片转化成8位的图片存储,此时label.png中像素实际以0,1,2…来分割图像,可以将灰度值放大来进行验证。
import argparse
import json
import os
import os.path as osp
import warnings
import numpy as np
import PIL.Image
from labelme import utils
def main():
'''
usage: python json2png.py json_file
'''
parser = argparse.ArgumentParser()
parser.add_argument('json_file')
parser.add_argument('-o', '--out', default=None)
args = parser.parse_args()
json_file = args.json_file
list = os.listdir(json_file)
for i in range(0, len(list)):
path = os.path.join(json_file, list[i])
if os.path.isfile(path):
data = json.load(open(path))
img = utils.img_b64_to_arr(data['imageData'])
lbl, lbl_names = utils.labelme_shapes_to_label(img.shape, data['shapes'])
captions = ['%d: %s' % (l, name) for l, name in enumerate(lbl_names)]
lbl_viz = utils.draw_label(lbl, img, captions)
out_dir = osp.basename(list[i]).replace('.', '_')
# out_dir = osp.join(osp.dirname(list[i]), out_dir)
out_dir = osp.join('./png', out_dir)
if not osp.exists(out_dir):
os.mkdir(out_dir)
PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))
lbl = PIL.Image.fromarray(np.uint8(lbl))
lbl.save(osp.join(out_dir, 'label.png'))
# PIL.Image.fromarray(lbl).save(osp.join(out_dir, 'label.png'))
PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, 'label_viz.png'))
print('Saved to: %s' % out_dir)
if __name__ == '__main__':
main()
- 将生成的img和label按原数据集http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip中MIT_SceneParsing/ADEChallengeData2016文件夹中类似存储。下面是批量存入annotations/training和images/training中的代码。若自己数据集较少,验证集validation中的图片可以直接从training中剪取部分,也可以修改下面的代码。
- annotations/training和images/training中文件名字一定要相同且一一对应,比如都为train-001.png。validation中的文件名同理。
import os
import shutil
filedir = './png'
outdir_a = './FCN.tensorflow-master/Data_zoo/power/myData/annotations'
outdir_i = './FCN.tensorflow-master/Data_zoo/power/myData/images'
filelists = os.listdir(filedir)
filelists.sort()
for i,filename in enumerate(filelists):
print(filename)
filename = os.path.join(filedir, filename)
shutil.move(os.path.join(filename,'label.png'),os.path.join(outdir_a,'training'))
name1 = 'train-' + str(i+1).zfill(3) + '.png'
os.rename(os.path.join(outdir_a,'training/label.png'),os.path.join(outdir_a,'training',name1))
shutil.move(os.path.join(filename,'img.png'),os.path.join(outdir_i,'training'))
name2 = 'train-' + str(i+1).zfill(3) + '.png'
os.rename(os.path.join(outdir_i,'training/img.png'),os.path.join(outdir_i,'training',name2))
训练
- 将FCN.py中NUM_OF_CLASSESS改为自己训练数据的类别数,注意加上背景,即分类物体+1。vgg-19预训练模型在程序运行中会进行下载,也可以在训练前下载好http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat放在Model_zoo文件夹中。将flag中的data_dir改为自己的数据集所在文件(我的是Data_zoo/power/),训练时mode为train。
- read_MITSceneParsingData.py中将pickle_filename改为你数据集的名字如
"power.pickle"
,将SceneParsing_folder令为自己的文件夹SceneParsing_folder = 'myData'
,删除掉下载数据集的语句utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
。 - 默认batch_size大小是2,迭代次数为10000次
- 损失函数可视化:
tensorboard --logdir ./logs/train
测试
在FCN.py中将mode改为visualize,网络生成的预测图像中灰度值不为0的点,可以在原图上对应位置将其灰度值修改为某固定值如200,就完成了可视化(这只有一种分类的情况,若是多种分类需要修改成对应不同的灰度值)。FCN.py中修改如下。
elif FLAGS.mode == "visualize":
#num: the number of images to be tested which can be a single batch_size or all validation set
valid_images, valid_annotations, num = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
pred = sess.run(pred_annotation, feed_dict={image: valid_images, keep_probability: 1.0})
pred = np.squeeze(pred, axis=3)
for itr in range(num):
src_img = valid_images[itr].astype(np.uint8)
pred_img = pred[itr].astype(np.uint8)
#save images to ./logs/test_visualize
utils.save_image(src_img, FLAGS.logs_dir + 'test_visualize/', name="inp_" + str(itr))
utils.save_image(pred_img, FLAGS.logs_dir + 'test_visualize/', name="pred_" + str(itr))
for i in range(pred_img.shape[0]):
for j in range(pred_img.shape[1]):
if pred_img[i,j] != 0:
#if your source images are RGB format, you need to change three channels
src_img[i,j]=200
utils.save_image(src_img, FLAGS.logs_dir + 'test_visualize/', name="visual_" + str(itr))
print("Saved image: %d" % itr)
相关文件都上传到了我的github中。