Training DeepLabV3+ on Your Own Dataset and Exporting a .pb Model for Visual Testing

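This article covers training a DeepLabV3+ model and testing it visually. For converting the ckpt model to tflite and deploying it on Android, a separate article is available for reference.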

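I. Environment setup and preparation

1. Download the TensorFlow DeepLabV3+ open-source project.
2. Install CUDA 10.0.
3. Install cuDNN 7.6.5 (the version must match the CUDA version).
4. Install Anaconda3.
5. Create a new virtual environment with Anaconda3 (using Python 3.7).
6. Install tensorflow-gpu in the virtual environment:
pip install tensorflow-gpu==1.15.3
7. Inside the virtual environment, go to the */research/slim folder of the project downloaded in step 1, delete the BUILD file, and run: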

python setup.py build
python setup.py install

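Then add the environment variable by running the following from the */research directory (note the backticks around pwd): export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim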
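To make it clearer what that export produces, here is a minimal sketch that builds the same PYTHONPATH value in Python. The function name `build_pythonpath` and the `models/research` location are assumptions for illustration only; adjust the path to wherever the project was actually cloned.

```python
import os


def build_pythonpath(research_dir, existing=""):
    """Return a PYTHONPATH value that includes research/ and research/slim."""
    research_dir = os.path.abspath(research_dir)
    # Keep whatever was already on PYTHONPATH, dropping empty entries.
    parts = [p for p in existing.split(os.pathsep) if p]
    for extra in (research_dir, os.path.join(research_dir, "slim")):
        if extra not in parts:
            parts.append(extra)
    return os.pathsep.join(parts)


if __name__ == "__main__":
    # Hypothetical checkout location; replace with your own.
    print(build_pythonpath("models/research", os.environ.get("PYTHONPATH", "")))
```

Both entries are needed because the deeplab package is imported relative to research/, while the slim packages are imported relative to research/slim.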

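II. Preparing the dataset

1. Dataset folder layout:
[Figure: dataset folder structure]
train.txt lists the names of the training images, without file extensions.
val.txt lists the names of the validation images, without file extensions.
trainval.txt lists the names of all images, without file extensions.
JPEGImages holds all the RGB images (.jpg format).
SegmentationClass holds the label images that correspond to the images in JPEGImages (.png format, single channel).
2. Create a tfrecord folder under VOC.
3. In build_data.py, change the default value of image_format from png to jpg, as follows: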

tf.app.flags.DEFINE_enum('image_format', 'jpg', ['jpg', 'jpeg', 'png'],
                         'Image format.')

4. Modify build_voc2012_data.py:

tf.app.flags.DEFINE_string('image_folder',
                           './VOC/JPEGImages',
                           'Folder containing images.')

tf.app.flags.DEFINE_string(
    'semantic_segmentation_folder',
    './VOC/SegmentationClass',
    'Folder containing semantic segmentation annotations.')

tf.app.flags.DEFINE_string(
    'list_folder',
    './VOC/ImageSets/Segmentation',
    'Folder containing lists for training and validation')

tf.app.flags.DEFINE_string(
    'output_dir',
    './VOC/tfrecord',
    'Path to save converted SSTable of TensorFlow examples.')


_NUM_SHARDS = 4  # number of shard files generated per split in the tfrecord folder
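The train.txt, val.txt and trainval.txt lists described in this section can be generated with a short script. The following is a rough sketch, not the author's method: the function name `write_split_lists`, the 10% validation fraction, and the fixed random seed are all assumptions chosen for illustration.

```python
import os
import random


def write_split_lists(image_dir, list_dir, val_fraction=0.1, seed=0):
    """Write train.txt, val.txt and trainval.txt with extension-free names."""
    # Collect the base names of all .jpg files in JPEGImages.
    names = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(image_dir)
        if f.lower().endswith(".jpg")
    )
    shuffled = names[:]
    random.Random(seed).shuffle(shuffled)
    num_val = int(len(shuffled) * val_fraction)  # assumed split ratio
    splits = {
        "val": sorted(shuffled[:num_val]),
        "train": sorted(shuffled[num_val:]),
        "trainval": names,
    }
    os.makedirs(list_dir, exist_ok=True)
    for split, split_names in splits.items():
        with open(os.path.join(list_dir, split + ".txt"), "w") as f:
            f.writelines(name + "\n" for name in split_names)
    return {split: len(split_names) for split, split_names in splits.items()}
```

With the folder layout used in this article, it would be called as `write_split_lists('./VOC/JPEGImages', './VOC/ImageSets/Segmentation')`.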

III. Training the model

1. Download a pretrained model.
2. Edit data_generator.py to register your own dataset:

_MYDATA = DatasetDescriptor(
    splits_to_sizes={
        'train': 20281,  # number of names listed in train.txt
        'val': 2798,  # number of names listed in val.txt
        'trainval': 23079,  # train + val
    },
    num_classes=2,
    ignore_label=255,
)

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'mydata': _MYDATA,
}
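The numbers in splits_to_sizes should match the list files of your own dataset. As a small sketch (the helper name `count_split_sizes` is hypothetical), the sizes can be read straight from the .txt files instead of being counted by hand:

```python
import os


def count_split_sizes(list_dir, splits=("train", "val", "trainval")):
    """Count the non-empty lines of each <split>.txt file in list_dir."""
    sizes = {}
    for split in splits:
        with open(os.path.join(list_dir, split + ".txt")) as f:
            sizes[split] = sum(1 for line in f if line.strip())
    return sizes
```

For example, `count_split_sizes('./VOC/ImageSets/Segmentation')` returns a dict that can be pasted into splits_to_sizes.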

3. Modify train_utils.py so that the logits layer is not restored from the pretrained checkpoint:

  exclude_list = ['global_step', 'logits']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)
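The effect of exclude_list can be illustrated without TensorFlow. The sketch below assumes that variables are excluded when their name starts with one of the listed scopes, which is my understanding of how the slim restore helper filters them; the function `filter_restorable` and the variable names in the usage note are made up for illustration.

```python
import re


def filter_restorable(variable_names, exclude_scopes):
    """Return the names that do not start with any excluded scope."""
    return [
        name
        for name in variable_names
        # re.match anchors at the start of the name, so this is a prefix test.
        if not any(re.match(scope, name) for scope in exclude_scopes)
    ]
```

With `exclude_scopes=['global_step', 'logits']`, a name such as `logits/semantic/weights` is dropped while backbone weights are kept, so the final classification layer is trained from scratch for the new number of classes.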

4. Modify train.py:

flags.DEFINE_string('train_logdir', "./datasets/VOC/trainout",
                    'Where the checkpoint and logs are stored.')
flags.DEFINE_integer('training_number_of_steps', 5000,
                     'The number of steps used for training')  # number of training iterations
flags.DEFINE_integer('train_batch_size', 12,
                     'The number of images in each batch during training.')
flags.DEFINE_list('train_crop_size', '513,513',
                  'Image crop size [height, width] during training.')
flags.DEFINE_string('tf_initial_checkpoint', "{PATH}/model.ckpt",
                    'The initial checkpoint in tensorflow format.')  # location of the pretrained model
flags.DEFINE_boolean('initialize_last_layer', False,
                     'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', True,
                     'Only consider logits as last layers or not.')
flags.DEFINE_string('dataset', 'mydata',
                    'Name of the segmentation dataset.')
flags.DEFINE_string('dataset_dir', "./datasets/VOC/tfrecord", 'Where the dataset reside.')
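Two quick sanity checks on these flag values can be done with plain arithmetic. The sketch below is only illustrative: the helper names are hypothetical, and the crop-size rule (crop_size = output_stride * k + 1) together with an output stride of 16 is an assumption based on the DeepLab project's own guidance rather than something stated in this article.

```python
def epochs_covered(steps, batch_size, num_train_images):
    """How many passes over the training set the given step count amounts to."""
    return steps * batch_size / num_train_images


def is_valid_crop_size(size, output_stride=16):
    """Check the assumed rule crop_size = output_stride * k + 1."""
    return (size - 1) % output_stride == 0


if __name__ == "__main__":
    # Values taken from the flags and dataset descriptor in this article.
    print(round(epochs_covered(5000, 12, 20281), 2))
    print(is_valid_crop_size(513))
```

With 5000 steps, a batch size of 12 and 20281 training images, training covers just under three passes over the data, so training_number_of_steps may need to be raised for a dataset of this size.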

5. Run train.py.
6. Visualize the loss during training:
Run tensorboard --logdir={PATH}/trainout. It prints a URL; open that URL in Google Chrome.

IV. Exporting the .pb model

1. Modify export_model.py:

flags.DEFINE_string('checkpoint_path', "./datasets/VOC/trainout/model.ckpt", 'Checkpoint path')

flags.DEFINE_string('export_path', "./datasets/VOC/trainout/pb/frozen_inference_graph.pb",
                    'Path to output Tensorflow frozen graph.')

flags.DEFINE_integer('num_classes', 2, 'Number of classes.')

flags.DEFINE_multi_integer('crop_size', [513, 513],
                           'Crop size [height, width].')

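Then run export_model.py. If the checkpoint files in trainout carry a step suffix (for example model.ckpt-5000), checkpoint_path most likely needs to include that suffix.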
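When the training directory holds several checkpoints, the prefix to export can be found with a few lines of standard-library Python. This is a sketch under the assumption that checkpoints are saved as model.ckpt-<step>.index alongside their data files; the function name `latest_checkpoint_prefix` is hypothetical.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-<step> prefix with the highest step, or None."""
    pattern = re.compile(r"^(model\.ckpt-(\d+))\.index$")
    best = None  # (step, prefix)
    for name in os.listdir(train_dir):
        match = pattern.match(name)
        if match and (best is None or int(match.group(2)) > best[0]):
            best = (int(match.group(2)), match.group(1))
    return os.path.join(train_dir, best[1]) if best else None
```

The returned prefix is the kind of value that would go into the checkpoint_path flag.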

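V. Visual testing

1. In the directory that contains the pb folder (one level above frozen_inference_graph.pb), pack the folder into an archive:
tar -czf pb.tar.gz pb
2. Create a file named deeplab_demo_test.py with the following content and run it: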
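On systems without the tar command, roughly the same archive can be produced with Python's tarfile module. This is a sketch only; the function name `pack_model_dir` is hypothetical, and it returns the member names so you can confirm that frozen_inference_graph.pb ended up inside the archive.

```python
import os
import tarfile


def pack_model_dir(pb_dir, archive_path):
    """Pack pb_dir into a gzip-compressed tar archive and list its members."""
    folder_name = os.path.basename(os.path.normpath(pb_dir))
    with tarfile.open(archive_path, "w:gz") as tar:
        # Store paths relative to the folder itself, like `tar -czf pb.tar.gz pb`.
        tar.add(pb_dir, arcname=folder_name)
    with tarfile.open(archive_path) as tar:
        return [member.name for member in tar.getmembers()]
```

For the paths used in this article, the call would be `pack_model_dir('./datasets/VOC/trainout/pb', './datasets/VOC/trainout/pb.tar.gz')`.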

import os
import tarfile

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib
import tensorflow as tf

class DeepLabModel(object):
    """
    Loads a DeepLab model and runs inference.
    """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513  # must match the crop_size used when exporting the .pb
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """
        Creates and loads pretrained deeplab model.
        """
        self.graph = tf.Graph()

        graph_def = None
        # Extract frozen graph from tar archive.
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break

        tar_file.close()

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """
        Runs inference on a single image.
        Args:
        image: A PIL.Image object, raw input image.
        Returns:
        resized_image: RGB image resized from original input image.
        seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Image.LANCZOS is the same filter as the older Image.ANTIALIAS alias.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,
                                      feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """
    Creates a label colormap used in PASCAL VOC segmentation benchmark.
    Returns:
        A Colormap for visualizing segmentation results.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """
    Adds color defined by the dataset colormap to the label.
    Args:
        label: A 2D array with integer type, storing the segmentation label.
    Returns:
        result: A 2D array with floating type. The element of the array
        is the color indexed by the corresponding element in the input label
        to the PASCAL color map.
    Raises:
        ValueError: If label is not of rank 2 or its value is larger than color
        map maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()


##
LABEL_NAMES = np.asarray(
    ['background', 'sky', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
     'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
     'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

## Models provided by TensorFlow for download
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']

_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
               'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
               'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
               'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }

# _TARBALL_NAME = 'deeplab_model.tar.gz'

# model_dir = tempfile.mkdtemp()
# tf.gfile.MakeDirs(model_dir)
#
# download_path = os.path.join(model_dir, _TARBALL_NAME)
# print('downloading model, this might take a while...')
# urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
# print('download completed! loading DeepLab model...')


download_path = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC/trainout/pb.tar.gz'  # location of the model archive
#download_path = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC2012/test_model/pb_53506.tar.gz'
#download_path = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/deeplabv3_cityscapes_train/deeplabv3_mnv2_pascal_train_aug_8bit/pb.tar.gz'  # location of the model archive

MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


##
def run_visualization(imagefile):
    """
    Runs DeepLab semantic segmentation and visualizes the result.
    """
    original_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(original_im)

    vis_segmentation(resized_im, seg_map)


images_dir = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC/JPEGImages'  # directory containing the test images
#images_dir = '/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC2012/JPEGImages'
images = sorted(os.listdir(images_dir))
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')
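The resizing step inside DeepLabModel.run scales the longer side of the image to INPUT_SIZE while keeping the aspect ratio. The standalone sketch below repeats that arithmetic so it can be checked without loading a model; the function name `target_size` is hypothetical.

```python
def target_size(width, height, input_size=513):
    """Scale (width, height) so that the longer side equals input_size."""
    resize_ratio = 1.0 * input_size / max(width, height)
    # int() truncates, mirroring the computation in DeepLabModel.run.
    return int(resize_ratio * width), int(resize_ratio * height)


if __name__ == "__main__":
    print(target_size(1024, 768))  # landscape input
    print(target_size(768, 1024))  # portrait input
```

A 1024x768 image is resized to 513x384, so the segmentation map returned by the model has the resized shape rather than the original one. To overlay it on the original image, the map would have to be scaled back up.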