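Training DeepLabV3+ on Your Own Dataset and Exporting a .pb Model for Visual Testing
This article walks through training a DeepLabV3+ model on a custom dataset and testing it visually. For converting the ckpt model to TFLite and deploying it on Android, there is a separate article that can be used as a reference.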
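I. Environment setup and preparation
1. Download TensorFlow's open-source DeepLabV3+ project (part of the tensorflow/models repository, under research/deeplab)
2. Install CUDA 10.0
3. Install cuDNN 7.6.5 (its version must match the CUDA version)
4. Install Anaconda3
5. Create a new virtual environment with Anaconda3 (use Python 3.7)
6. Install tensorflow-gpu in the virtual environment: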
pip install tensorflow-gpu==1.15.3
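7. In the virtual environment, go to the */research/slim folder of the project downloaded in step 1, delete the BUILD file there, and run:
python setup.py build
python setup.py install
Then go back to the */research folder and add it and its slim subfolder to PYTHONPATH (this only lasts for the current terminal session):
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim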
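Before moving on, it can help to check that both paths really made it onto PYTHONPATH. The sketch below does only that. The function name and the example path are hypothetical. It inspects the PYTHONPATH string only and does not import TensorFlow or DeepLab.

```python
import os


def missing_pythonpath_entries(research_dir, pythonpath=None):
    """Return which of research_dir and research_dir/slim are absent from PYTHONPATH.

    research_dir: path of .../models/research (hypothetical example below).
    pythonpath: string to inspect; defaults to the current environment's PYTHONPATH.
    """
    if pythonpath is None:
        pythonpath = os.environ.get("PYTHONPATH", "")
    # Normalize entries so trailing slashes or empty segments do not cause false alarms.
    entries = {os.path.normpath(p) for p in pythonpath.split(os.pathsep) if p}
    required = [os.path.normpath(research_dir),
                os.path.normpath(os.path.join(research_dir, "slim"))]
    return [p for p in required if p not in entries]


if __name__ == "__main__":
    # Hypothetical location; replace with your own checkout of tensorflow/models.
    missing = missing_pythonpath_entries("/path/to/models/research")
    print("missing from PYTHONPATH:", missing or "nothing")
```

An empty result means both entries are present. If `export` was run from research/slim instead of research, the check reports the research path as missing.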
II. Preparing the dataset
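1. Dataset folder layout (inside the VOC folder):
ImageSets/Segmentation/train.txt: names of the training images, one per line, without file extension
ImageSets/Segmentation/val.txt: names of the validation images, without file extension
ImageSets/Segmentation/trainval.txt: names of all images, without file extension
JPEGImages: all RGB images (.jpg)
SegmentationClass: the label image for each image in JPEGImages (.png, same file name, single channel, with each pixel value being the class index: 0 to num_classes - 1, or 255 for pixels to ignore)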
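If the three list files do not exist yet, they can be generated from the folders above. Below is a minimal sketch that assumes this layout. It lists every .jpg in JPEGImages that has a .png of the same name in SegmentationClass. The split is a hypothetical random one: val_ratio=0.1 and seed=0 are arbitrary choices, not taken from the original. The sketch writes train.txt, val.txt and trainval.txt and returns their line counts, which are the numbers needed later for splits_to_sizes.

```python
import os
import random


def write_split_lists(voc_dir, val_ratio=0.1, seed=0):
    """Write train.txt, val.txt and trainval.txt under voc_dir/ImageSets/Segmentation.

    Only images with a matching label PNG are listed; names are written without
    extensions, one per line. Returns {split_name: number_of_lines}.
    """
    image_dir = os.path.join(voc_dir, "JPEGImages")
    label_dir = os.path.join(voc_dir, "SegmentationClass")
    list_dir = os.path.join(voc_dir, "ImageSets", "Segmentation")
    os.makedirs(list_dir, exist_ok=True)

    # Exact ".jpg" match, because the conversion script appends ".jpg" to each name.
    names = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(image_dir)
        if f.endswith(".jpg")
        and os.path.isfile(os.path.join(label_dir, os.path.splitext(f)[0] + ".png"))
    )
    shuffled = names[:]
    random.Random(seed).shuffle(shuffled)  # fixed seed -> reproducible split
    num_val = int(len(shuffled) * val_ratio)
    splits = {
        "val": sorted(shuffled[:num_val]),
        "train": sorted(shuffled[num_val:]),
        "trainval": names,
    }
    for split, split_names in splits.items():
        with open(os.path.join(list_dir, split + ".txt"), "w") as f:
            f.writelines(n + "\n" for n in split_names)
    return {split: len(split_names) for split, split_names in splits.items()}
```

For example, `write_split_lists("./VOC")`, run from the folder that contains VOC, writes the lists and returns a dict with one count per split.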
2. Create a tfrecord folder under VOC.
3. In build_data.py, change the default of image_format from png to jpg:
tf.app.flags.DEFINE_enum('image_format', 'jpg', ['jpg', 'jpeg', 'png'],
'Image format.')
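build_voc2012_data.py builds each file path from the list entries as <folder>/<name>.<format>. It uses image_format from build_data.py ('jpg' after the change above) for images and png for labels. Every listed name therefore needs both files, with exactly those extensions. The sketch below reports entries that would fail. The helper name is hypothetical.

```python
import os


def find_missing_files(voc_dir, split="train", image_format="jpg", label_format="png"):
    """Return (name, missing_path) pairs for entries of <split>.txt lacking a file."""
    list_path = os.path.join(voc_dir, "ImageSets", "Segmentation", split + ".txt")
    with open(list_path) as f:
        names = [line.strip() for line in f if line.strip()]
    missing = []
    for name in names:
        # Same path construction as the layout above: folder/name.extension
        for folder, ext in (("JPEGImages", image_format),
                            ("SegmentationClass", label_format)):
            path = os.path.join(voc_dir, folder, name + "." + ext)
            if not os.path.isfile(path):
                missing.append((name, path))
    return missing
```

For example, run `find_missing_files("./VOC", split)` for "train", "val" and "trainval". An empty list means every entry has both files.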
4. Modify build_voc2012_data.py (in deeplab/datasets; the relative paths below assume it is run from that directory):
tf.app.flags.DEFINE_string('image_folder',
'./VOC/JPEGImages',
'Folder containing images.')
tf.app.flags.DEFINE_string(
'semantic_segmentation_folder',
'./VOC/SegmentationClass',
'Folder containing semantic segmentation annotations.')
tf.app.flags.DEFINE_string(
'list_folder',
'./VOC/ImageSets/Segmentation',
'Folder containing lists for training and validation')
tf.app.flags.DEFINE_string(
'output_dir',
'./VOC/tfrecord',
'Path to save converted SSTable of TensorFlow examples.')
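_NUM_SHARDS = 4  # number of .tfrecord files written per split (train, val, trainval)
5. Run build_voc2012_data.py. The converted .tfrecord files are written to VOC/tfrecord.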
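To see which images end up in which output file, the sketch below reproduces the sharding arithmetic in build_voc2012_data.py as I understand it. Each split is cut into _NUM_SHARDS contiguous chunks of ceil(N / _NUM_SHARDS) images. Each chunk is written as <split>-<shard>-of-<total>.tfrecord. Treat the exact filename pattern as an assumption to compare against your copy of the script. As far as I can tell, data_generator.py later finds these files by the split-name prefix ('<split>-*'). The split names in splits_to_sizes must therefore match the list file names.

```python
import math


def shard_plan(split_name, num_images, num_shards=4):
    """List (filename, start_idx, end_idx) for each tfrecord shard of one split."""
    num_per_shard = int(math.ceil(num_images / float(num_shards)))
    plan = []
    for shard_id in range(num_shards):
        start_idx = shard_id * num_per_shard
        end_idx = min((shard_id + 1) * num_per_shard, num_images)
        filename = "%s-%05d-of-%05d.tfrecord" % (split_name, shard_id, num_shards)
        plan.append((filename, start_idx, end_idx))
    return plan


if __name__ == "__main__":
    # 20281 is the training-split size used in the DatasetDescriptor of this article.
    for filename, start, end in shard_plan("train", 20281):
        print(filename, end - start)
```

With 20281 training images, this prints four files holding 5071, 5071, 5071 and 5068 images.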
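III. Training the model
1. Download a pretrained checkpoint (for example from the DeepLab model zoo).
2. Modify data_generator.py (in deeplab/datasets) to register your own dataset: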
_MYDATA = DatasetDescriptor(
splits_to_sizes={
'train': 20281,  # number of names in train.txt
'val': 2798,  # number of names in val.txt
'trainval': 23079  # number of names in trainval.txt (train + val)
},
num_classes=2,
ignore_label=255,
)
_DATASETS_INFORMATION = {
'cityscapes': _CITYSCAPES_INFORMATION,
'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
'ade20k': _ADE20K_INFORMATION,
'mydata':_MYDATA,
}
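The num_classes=2 and ignore_label=255 settings above assume each label PNG stores class indices directly. Below is a minimal sketch that reports label images which are not single-channel or contain values outside 0..num_classes-1 and ignore_label. It assumes numpy and Pillow are installed; the demo script in part V uses both. The function names are hypothetical. Requiring mode 'L' is a conservative assumption. As far as I know, the conversion script decodes labels as one-channel PNGs, so palette-mode masks (like the original VOC release) are normally converted to raw index images first.

```python
import numpy as np


def check_label_array(label, num_classes=2, ignore_label=255):
    """Return a list of problems found in one label array (empty if it looks valid)."""
    if label.ndim != 2:
        return ["expected a single-channel (H, W) array, got shape %s" % (label.shape,)]
    allowed = set(range(num_classes)) | {ignore_label}
    bad = sorted(int(v) for v in np.unique(label) if int(v) not in allowed)
    return ["unexpected label values %s" % bad] if bad else []


def check_label_file(path, num_classes=2, ignore_label=255):
    """Open one label PNG with Pillow and check its mode and pixel values."""
    from PIL import Image  # local import: check_label_array itself needs only numpy

    with Image.open(path) as img:
        # 'L' = 8-bit single channel. Palette ('P') PNGs store indices plus a color
        # table and are usually converted first (deeplab's remove_gt_colormap.py).
        problems = [] if img.mode == "L" else ["mode is %s, expected 'L'" % img.mode]
        return problems + check_label_array(np.asarray(img), num_classes, ignore_label)
```

For example, loop over os.listdir("./VOC/SegmentationClass") and print every file for which check_label_file returns a non-empty list. A frequent cause of errors is saving color-coded RGB masks instead of index masks.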
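3. Modify train_utils.py so that the logits layer is never restored from the pretrained checkpoint, whose logits (if present) were trained for a different number of classes than num_classes=2. With initialize_last_layer=False and last_layers_contain_logits_only=True, as set in train.py below, 'logits' is already excluded through last_layers; listing it explicitly makes the exclusion unconditional.
exclude_list = ['global_step', 'logits']  # variables that will not be restored
if not initialize_last_layer:
    exclude_list.extend(last_layers)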
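Roughly, train_utils.py passes exclude_list to a get_variables_to_restore call. Any variable whose name matches one of the excluded scopes from its beginning is left out of the restore and starts from fresh initialization. The pure-Python sketch below imitates that prefix-style filtering. It shows why 'logits' keeps the pretrained classifier weights out of a 2-class model. It does not call TensorFlow, and the variable names are hypothetical examples.

```python
import re


def variables_to_restore(variable_names, exclude):
    """Keep names that do not match (re.match, i.e. from the start) any excluded scope."""
    return [name for name in variable_names
            if not any(re.match(scope, name) for scope in exclude)]


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    names = [
        "global_step",
        "xception_65/entry_flow/conv1_1/weights",
        "aspp0/weights",
        "logits/semantic/weights",
        "logits/semantic/biases",
    ]
    print(variables_to_restore(names, ["global_step", "logits"]))
```

This prints ['xception_65/entry_flow/conv1_1/weights', 'aspp0/weights']. The backbone and ASPP weights come from the checkpoint, while global_step and the logits are initialized fresh.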
4. Modify the flag defaults in train.py:
flags.DEFINE_string('train_logdir', "./datasets/VOC/trainout",
'Where the checkpoint and logs are stored.')
flags.DEFINE_integer('training_number_of_steps', 5000,
'The number of steps used for training')  # number of training iterations
flags.DEFINE_integer('train_batch_size', 12,
'The number of images in each batch during training.')
flags.DEFINE_list('train_crop_size', '513,513',
'Image crop size [height, width] during training.')
flags.DEFINE_string('tf_initial_checkpoint', "{PATH}/model.ckpt",
'The initial checkpoint in tensorflow format.')  # location of the pretrained model
flags.DEFINE_boolean('initialize_last_layer', False,
'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', True,
'Only consider logits as last layers or not.')
flags.DEFINE_string('dataset', 'mydata',
'Name of the segmentation dataset.')
flags.DEFINE_string('dataset_dir', "./datasets/VOC/tfrecord", 'Where the dataset reside.')
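A quick sanity check on these numbers: training_number_of_steps counts batches, not epochs. 5000 steps at batch size 12 is 5000 × 12 / 20281 ≈ 2.96 passes over the 20281 training images from the DatasetDescriptor. The DeepLab documentation also suggests crop sizes of the form output_stride × k + 1, for example 513 = 16 × 32 + 1. The helper names below are hypothetical. output_stride=16 is an assumption, since this article does not set it.

```python
def epochs_covered(num_steps, batch_size, num_train_images):
    """Approximate number of passes over the training set."""
    return num_steps * batch_size / float(num_train_images)


def crop_size_ok(crop_size, output_stride=16):
    """True if every crop dimension has the form output_stride * k + 1."""
    return all((s - 1) % output_stride == 0 for s in crop_size)


if __name__ == "__main__":
    # Values taken from the train.py flags and the DatasetDescriptor above.
    print("epochs: %.2f" % epochs_covered(5000, 12, 20281))
    print("crop 513,513 ok:", crop_size_ok([513, 513]))
```

This prints `epochs: 2.96` and `crop 513,513 ok: True`.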
5. Run train.py from the research/deeplab directory (the relative paths above assume this): python train.py
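Instead of editing the defaults, the same flags can be passed on the command line, because train.py defines them with TensorFlow's flags module. The sketch below only assembles and prints the command, with values copied from the flags above and {PATH} left as a placeholder. The helper name is hypothetical.

```python
import shlex


def train_command(flag_values, script="train.py"):
    """Turn a dict of flag values into an argv list for train.py."""
    args = ["python", script]
    for name, value in flag_values.items():
        if isinstance(value, bool):
            value = "true" if value else "false"  # boolean flags accept --name=true/false
        args.append("--%s=%s" % (name, value))
    return args


if __name__ == "__main__":
    # Values copied from the train.py flags above; {PATH} is still a placeholder.
    cmd = train_command({
        "train_logdir": "./datasets/VOC/trainout",
        "training_number_of_steps": 5000,
        "train_batch_size": 12,
        "train_crop_size": "513,513",
        "tf_initial_checkpoint": "{PATH}/model.ckpt",
        "initialize_last_layer": False,
        "last_layers_contain_logits_only": True,
        "dataset": "mydata",
        "dataset_dir": "./datasets/VOC/tfrecord",
    })
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

After replacing {PATH}, either paste the printed command into a shell opened in research/deeplab, or call subprocess.run(cmd, check=True) from that directory.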
6. Visualizing the loss during training: run
tensorboard --logdir={PATH}/trainout
It prints a URL (port 6006 by default); open it in a browser such as Chrome.
IV. Exporting the .pb model
1. Modify export_model.py:
flags.DEFINE_string('checkpoint_path', "./datasets/VOC/trainout/model.ckpt-5000", 'Checkpoint path')  # checkpoint prefix; the suffix is the training step
flags.DEFINE_string('export_path', "./datasets/VOC/trainout/pb/frozen_inference_graph.pb",
'Path to output Tensorflow frozen graph.')
flags.DEFINE_integer('num_classes', 2, 'Number of classes.')
flags.DEFINE_multi_integer('crop_size', [513, 513],
'Crop size [height, width].')
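2. Run export_model.py from research/deeplab, since the paths above are relative to it. If you trained with non-default model flags such as model_variant, atrous_rates, output_stride or decoder_output_stride, pass the same values to export_model.py so that the exported graph matches the checkpoint.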
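The trainer saves checkpoints under a step-numbered prefix such as model.ckpt-5000. On disk this appears as, for example, model.ckpt-5000.index, model.ckpt-5000.data-00000-of-00001 and model.ckpt-5000.meta. checkpoint_path must be such a prefix, not the bare model.ckpt. In TensorFlow, tf.train.latest_checkpoint(train_logdir) returns the newest prefix recorded in the checkpoint state file. The dependency-free sketch below instead picks the highest step among the .index files, which usually gives the same answer. The helper name is hypothetical.

```python
import os
import re


def latest_checkpoint_prefix(train_logdir, basename="model.ckpt"):
    """Return the path prefix of the highest-step checkpoint, or None if there is none."""
    pattern = re.compile(re.escape(basename) + r"-(\d+)\.index$")
    steps = []
    for filename in os.listdir(train_logdir):
        match = pattern.match(filename)
        if match:
            steps.append(int(match.group(1)))  # compare numerically: 5000 > 900
    if not steps:
        return None
    return os.path.join(train_logdir, "%s-%d" % (basename, max(steps)))
```

For example, `latest_checkpoint_prefix("./datasets/VOC/trainout")` returns a value that can be used as checkpoint_path. Comparing steps as integers matters, because a plain string sort would rank "model.ckpt-900" after "model.ckpt-5000".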
V. Visual testing
1. In the directory that contains the pb folder (the parent of the folder holding frozen_inference_graph.pb), pack it into an archive:
tar -czf pb.tar.gz pb
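DeepLabModel in the script below loads the first archive member whose file name contains frozen_inference_graph, so it is worth checking the archive before running the demo. The sketch below uses Python's tarfile module to do the same job as `tar -czf pb.tar.gz pb` and then lists the matching members. The function name is hypothetical.

```python
import os
import tarfile


def pack_and_check(pb_dir, tarball_path, graph_name="frozen_inference_graph"):
    """Create a .tar.gz of pb_dir (like `tar -czf pb.tar.gz pb`) and return matching members."""
    with tarfile.open(tarball_path, "w:gz") as tar:
        # arcname keeps member paths relative, e.g. "pb/frozen_inference_graph.pb".
        tar.add(pb_dir, arcname=os.path.basename(os.path.normpath(pb_dir)))
    with tarfile.open(tarball_path) as tar:
        return [m.name for m in tar.getmembers()
                if graph_name in os.path.basename(m.name)]
```

An empty return value means the demo would raise "Cannot find inference graph in tar archive." Write the tarball outside pb_dir so that the archive does not try to include itself.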
2. Create and run deeplab_demo_test.py:
import os
import tarfile
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib
import tensorflow as tf
class DeepLabModel(object):
    """
    Loads a DeepLab frozen graph and runs inference.
    """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513  # must match the crop_size used when exporting the .pb
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'
    def __init__(self, tarball_path):
        """
        Creates and loads pretrained deeplab model.
        """
        self.graph = tf.Graph()
        graph_def = None
        # Extract frozen graph from tar archive.
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break
        tar_file.close()
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)
    def run(self, image):
        """
        Runs inference on a single image.
        Args:
            image: A PIL.Image object, raw input image.
        Returns:
            resized_image: RGB image resized from original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS (removed in Pillow 10).
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map
def create_pascal_label_colormap():
    """
    Creates a label colormap used in PASCAL VOC segmentation benchmark.
    Returns:
        A Colormap for visualizing segmentation results.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap
def label_to_color_image(label):
    """
    Adds color defined by the dataset colormap to the label.
    Args:
        label: A 2D array with integer type, storing the segmentation label.
    Returns:
        result: An (H, W, 3) integer array. Each element is the color that the
            PASCAL colormap assigns to the corresponding element of the input label.
    Raises:
        ValueError: If label is not of rank 2 or its value is larger than color
            map maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = create_pascal_label_colormap()
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')
    return colormap[label]
def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])
    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')
    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')
    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')
    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)  # a boolean; the string 'off' is not a reliable way to disable the grid
    plt.show()
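##
# Class names indexed by label value. With num_classes=2 only labels 0 ('background')
# and 1 ('sky') can occur; the remaining entries are leftover PASCAL VOC names.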
LABEL_NAMES = np.asarray(
['background', 'sky', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'])
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
## Pretrained models provided by TensorFlow (only used by the commented-out download code below)
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']
_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }
# _TARBALL_NAME = 'deeplab_model.tar.gz'
# model_dir = tempfile.mkdtemp()
# tf.gfile.MakeDirs(model_dir)
#
# download_path = os.path.join(model_dir, _TARBALL_NAME)
# print('downloading model, this might take a while...')
# urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
# print('download completed! loading DeepLab model...')
download_path = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC/trainout/pb.tar.gz'  # location of the packed model
#download_path = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC2012/test_model/pb_53506.tar.gz'
#download_path = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/deeplabv3_cityscapes_train/deeplabv3_mnv2_pascal_train_aug_8bit/pb.tar.gz'  # location of the packed model
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')
##
def run_visualization(imagefile):
    """
    Runs DeepLab semantic segmentation on one image and visualizes the result.
    """
    original_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(original_im)
    vis_segmentation(resized_im, seg_map)
images_dir = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC/JPEGImages'  # directory of the test images
#images_dir = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC2012/JPEGImages'
images = sorted(os.listdir(images_dir))
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))
print('Done.')
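DeepLabModel.run scales the longer side of the input to INPUT_SIZE (513) before inference, so seg_map has the resized shape rather than the original one. The sketch below repeats that size computation. Going beyond the original script, it also upsamples seg_map back to the original resolution with nearest-neighbor indexing in numpy. Nearest-neighbor is used so that no new, invalid class values are interpolated. The function names are hypothetical.

```python
import numpy as np


def target_size(width, height, input_size=513):
    """Same computation as DeepLabModel.run: scale the longer side to input_size."""
    resize_ratio = 1.0 * input_size / max(width, height)
    return int(resize_ratio * width), int(resize_ratio * height)


def upsample_seg_map(seg_map, width, height):
    """Nearest-neighbor resize of an (h, w) label map back to (height, width)."""
    h, w = seg_map.shape
    rows = np.arange(height) * h // height  # source row for each output row
    cols = np.arange(width) * w // width    # source column for each output column
    return seg_map[rows[:, None], cols[None, :]]
```

For a 1024×768 image, target_size(1024, 768) gives (513, 384). upsample_seg_map(seg_map, 1024, 768) then returns a 768×1024 label array aligned with the original image, containing only values that already occur in seg_map.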