Kitticlass训练自己的数据集

MultiNet是一种高效前馈架构,可同时处理语义分割、图像分类和目标检测,通过共享编码器和特定任务解码器实现多任务联合训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

**

- Multinet中Kitticlss简介

**
Multinet架构
(处理图像提取特征的共享编码器和实现给定任务的解码器)
MultiNet是一种高效的前馈架构, 以联合的理解语义分割, 图像分类和目标检测。在三个任务中共享一个共同的编码器, 并且具有三个分支, 每个分支是一个实现给定任务的解码器。
MultiNet可以进行端到端的训练, 所有任务的联合推理可以在100ms内完成。我们开始讨论并根据分类解码器来介绍联合编码器。
编码器的任务是处理图像并提取丰富的抽象特征[49], 该特征包含了执行准确分割, 检测和图像分类的所必要的信息。MultiNet编码器由VGG16网络的前13层组成[45],VGGnet总是被用来提取图像特征, 应用全卷积方式产生39×12×512大小的张量。这是第5个pooling层的输出, 在VGG中叫作pool5。分类解码器的设计以利用编码器的优点。为了实现这一目标, 我们应用1×1卷积,产生3912300的隐藏层, 然后用全连接层和softmax层输出最后类的概率。
**

- 跑通并学习demo

**

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Marvin Teichmann


"""
Classify an image using KittiClass.

Input: Image
Output: Image (with Cars plotted in Green)

Utilizes: Trained KittiClass weights. If no logdir is given,
pretrained weights will be downloaded and used.

Usage:
python demo.py --input data/demo.png [--output output]
                [--logdir /path/to/weights] [--gpus 0]


"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import os
import sys

import collections

# configure logging
#定义日志,记录关键节点信息
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)

# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf


flags = tf.app.flags #用于接受命令行传递参数,处理命令行参数的解析工作,(JSON)
FLAGS = flags.FLAGS #构造一个解析器FLAGS

sys.path.insert(1, 'incl')#定义搜索的优先顺序
#新添加的目录会优先于其他目录被import检查
#我的项目中没有这个文件incl,所以没有用到,见65行
try:
    # Check whether setup was done correctly

    import tensorvision.utils as tv_utils
    import tensorvision.core as core
except ImportError:
    # You forgot to initialize submodules
    logging.error("Could not import the submodules.")
    logging.error("Please execute:"
                  "'git submodule update --init --recursive'")
    exit(1)

#原本程序中是None,我没有incl文件说以人为指定第二个参数
flags.DEFINE_string('logdir', 'logdir',
                    'Path to logdir.')
#tf.app.flags.DEFINE_string() 定义一个用于接收string 类型数值的变量
#三个参数分别是:变量名称、默认值、用法描述,
#cjh-2019-4-16
flags.DEFINE_string('input', 'DATA/demo/007007.png',
                    'Image to apply KittiSeg.')
flags.DEFINE_string('output', 'DATA/demo/007007_.png',
                    'Image to apply KittiSeg.')


default_run = 'KittiClass_postpaper' #文件夹的名字
#可能是加载欲训练网络的权重,
weights_url = ("ftp://mi.eng.cam.ac.uk/"
               "pub/mttt2/models/KittiClass_postpaper.zip")
#路径为ftp://mi.eng.cam.ac.uk/pub/mttt2/models/
#自行手动下载

from PIL import Image, ImageDraw, ImageFont

#在图片上画结果
def road_draw(image, highway):
    im = Image.fromarray(image.astype('uint8'))
    draw = ImageDraw.Draw(im)

    #fnt = ImageFont.truetype('FreeMono/FreeMonoBold.ttf', 40)
    fnt = ImageFont.truetype('simhei.ttf', 40)
    shape = image.shape

    if highway:
        draw.text((65, 10), "Highway",
                  font=fnt, fill=(255, 255, 0, 255))

        draw.ellipse([10, 10, 55, 55], fill=(255, 255, 0, 255),
                     outline=(255, 255, 0, 255))
    else:
        draw.text((65, 10), "small road",
                  font=fnt, fill=(255, 0, 0, 255))

        draw.ellipse([10, 10, 55, 55], fill=(255, 0, 0, 255),
                     outline=(255, 0, 0, 255))

    return np.array(im).astype('float32')

#runs_dir为输入值,目的是构造路径,整个函数的目的就是下载权重,
# 然后解压zip文件,如果存在则不进行任何操作。
def maybe_download_and_extract(runs_dir):
    logdir = os.path.join(runs_dir, default_run)#构造路径

    if os.path.exists(logdir):
        # weights are downloaded. Nothing to do
        return
    #解压文件
    import zipfile
    download_name = tv_utils.download(weights_url, runs_dir)#下载

    logging.info("Extracting KittiSeg_pretrained.zip")

    zipfile.ZipFile(download_name, 'r').extractall(runs_dir)

    return

#重新定义图片的大小,第一个是原图,第二个是标记图像
def resize_label_image(image, gt_image, image_height, image_width):
    image = scp.misc.imresize(image, size=(image_height, image_width),
                              interp='cubic')
    shape = gt_image.shape
    gt_image = scp.misc.imresize(gt_image, size=(image_height, image_width),
                                 interp='nearest')

    return image, gt_image


def main(_):
    ##设置运行代码的GPU
    tv_utils.set_gpus_to_use()
    if FLAGS.input is None:
        logging.error("No input was given.")
        logging.info(
            "Usage: python demo.py --input data/test.png "
            "[--output output] [--logdir /path/to/weights] "
            "[--gpus GPUs_to_use] ")
        exit(1)

    if FLAGS.logdir is None:
        # Download and use weights from the MultiNet Paper
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(os.environ['TV_DIR_RUNS'],
                                    'KittiClass')
        else:
            runs_dir = 'RUNS'
        maybe_download_and_extract(runs_dir)
        logdir = os.path.join(runs_dir, default_run)
    else:
        logging.info("Using weights found in {}".format(FLAGS.logdir))
        logdir = FLAGS.logdir
    #这段在做的就是找到权重的存放路径并赋值给logdir
    # Loading hyperparameters from logdir,从下载的权重的存放路径加载超参
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')

    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    with tf.Graph().as_default():
        # Create placeholder for input
        image_pl = tf.placeholder(tf.float32)
        image = tf.expand_dims(image_pl, 0)

        # build Tensorflow graph using the model from logdir
        prediction = core.build_inference_graph(hypes, modules,
                                                image=image)

        logging.info("Graph build successfully.")

        # Create a session for running Ops on the Graph.
        sess = tf.Session()
        saver = tf.train.Saver()#模型保存,先要创建一个saver对象

        # Load weights from logdir
        #cjh-2019-4-16
        # logdir=r'/home/xue/MultiNet-master/submodules/KittiClass/KittiClass_postpaper'
        logdir = r'E:/lixueqian/2019/new_method/MultiNet-master1/submodules/KittiClass/KittiClass_postpaper'
        core.load_weights(logdir, sess, saver)

        logging.info("Weights loaded successfully.")

    input = FLAGS.input
    logging.info("Starting inference using {} as input".format(input))  #使用的方法
    ####encoder.inference 好像是构建完整的VGG16NET
    # Load and resize input image
    # input=r'/home/xue/MultiNet-master/submodules/KittiClass/KittiClass-master/DATA/demo/007034.png'
    image = scp.misc.imread(input)
    if hypes['jitter']['reseize_image']:
        # Resize input only, if specified in hypes
        image_height = hypes['jitter']['image_height']
        image_width = hypes['jitter']['image_width']
        image = scp.misc.imresize(image, size=(image_height, image_width),
                                  interp='cubic')

    # Run KittiSeg model on image
    feed = {image_pl: image} #为placeholder赋值
    softmax_road, _ = prediction['softmax']#
    output = sess.run([softmax_road], feed_dict=feed)
    # output是分类的概率
    # Get predicted class
    highway = (np.argmax(output[0][0]) == 0)
    # highway是将output转化为bool值
    # Draw resulting output image
    new_img = road_draw(image, highway)

    # Save output images to disk.
    if FLAGS.output is None:
        output_base_name = input
        out_image_name = output_base_name.split('.')[0] + '_out.png'
    else:
        out_image_name = FLAGS.output

    scp.misc.imsave(out_image_name, new_img)

    logging.info("")
    logging.info("Output image has been saved to: {}".format(
        os.path.realpath(out_image_name)))

if __name__ == '__main__':
    tf.app.run()  #使用flags又解析了一次,tf.app.run的作用仅仅是指定main主函数和使用flags再解析一次输入

结果展示:在这里插入图片描述
在这里插入图片描述

-train源码学习

(与demo中重复的不在赘述)

  • 预定义flags:
    包括name、project、hypes、mod以及bool类型的model,其中hypes是存储模型参数的,hypes\KittiClass.json
  • 定义dict_merge函数
    def dict_merge(dct, merge_dct):
    功能:递归字典合并。dict_merge会向下递归到任意深度的嵌套的dict中去表层更新键而不是仅更新顶级键,dict_merge会合并到dct中
    函数学习:(1)isinstance(object, classinfo) 如果参数object是classinfo的实例,或者object是classinfo类的子类的一个实例, 返回True
def dict_merge(dct, merge_dct):
    """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct``.
    :param dct: dict onto which the merge is executed
    :param merge_dct: dct merged into dct
    :return: None
    """
    for k, v in merge_dct.iteritems():
        if (k in dct and isinstance(dct[k], dict) and
                isinstance(merge_dct[k], collections.Mapping)):
            dict_merge(dct[k], merge_dct[k])
        else:
            dct[k] = merge_dct[k]
  • 主函数
    设置GPU,导入tensorvision.train和tensorflow_fcn.utils,将tf.app.flags.FLAGS.hypes打开用json.load直接读取,导入KittiBox.json成功,utils.load_plugins()功能是“Load all TensorVision plugins(加载所有TensorVision插件)."
if tf.app.flags.FLAGS.mod is not None:
    import ast
    mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)#将
    dict_merge(hypes, mod_dict)

以下两者等价:eval() 函数用来执行一个字符串表达式,并返回表达式的值。在这里就是将string 转化为dict 并且验证执行合法的类型

import ast 
my_list = ast.literal_eval(expr)
expr = "[1, 2, 3]" 
my_list = eval(expr)
utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)
    #在路径为RUNS/kittiBox../model_files下面添加文件,创建一个image文件夹和一个输出日志文件
    utils._add_paths_to_sys(hypes)

    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    #在路径为RUNS/kittiBox../model_files下面添加文件,创建一个image文件夹和一个输出日志文件

就是在做这部分吧。
‘’’’’
RUNS/KittiBox_date_/images :
RUNS/KittiBox_date_/model_files :
architecure.py
data_input.py
eval.py
hypes.py
objective.py
solver.py
RUNS/KittiBox_date_/output.log
‘’’’’
train.maybe_download_and_extract(hypes)下载解压zip文件,如果存在则不进行任何操作,开始训练。hypes是hypes.json。
哒哒哒转到tensorvision中的train.py中的do_training,这个函数包含了训练模型的一些步骤。

  • do_training
def do_training(hypes):
    """
    Train model for a number of steps.

    This trains the model for at most hypes['solver']['max_steps'].
    It shows an update every utils.cfg.step_show steps and writes
    the model to hypes['dirs']['output_dir'] every utils.cfg.step_eval
    steps.

    Paramters
    ---------
    hypes : dict
        Hyperparameters
    """
    # Get the sets of images and labels for training, validation, and
    # test on MNIST.
    #加载几个训练相关的.py文件
    modules = utils.load_modules_from_hypes(hypes)

'''''
[‘input’] : ../inputs/Kitti_input.py
[‘arch’] : ../encoder/vgg.py
[‘objective’] : ../decoder/fastBox.py
[‘solver’] : ../optimizer/generic_optimizer.py
[‘eval’] : ../evals/kitti_eval.py
'''''
    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Session() as sess:

        # build the graph based on the loaded modules
        with tf.name_scope("Queues"):
            queue = modules['input'].create_queues(hypes, 'train')

        tv_graph = core.build_training_graph(hypes, queue, modules)

        # prepaire the tv session
        tv_sess = core.start_tv_session(hypes)

        with tf.name_scope('Validation'):
            tf.get_variable_scope().reuse_variables()
            image_pl = tf.placeholder(tf.float32)
            image = tf.expand_dims(image_pl, 0)
            image.set_shape([1, None, None, 3])
            inf_out = core.build_inference_graph(hypes, modules,
                                                 image=image)
            tv_graph['image_pl'] = image_pl
            tv_graph['inf_out'] = inf_out

        # Start the data load
        modules['input'].start_enqueuing_threads(hypes, queue, 'train', sess)

        # And then after everything is built, start the training loop.
        run_training(hypes, modules, tv_graph, tv_sess)

        # stopping input Threads
        tv_sess['coord'].request_stop()
        tv_sess['coord'].join(tv_sess['threads'])

modules = utils.load_modules_from_hypes(hypes)功能是从hypes指定的文件中加载模型。 即加载的模块是: input_file, architecture_file, objective_file, optimizer_file。

 with tf.name_scope("Queues"):
            queue = modules['input'].create_queues(hypes, 'train')

根据加载的模型创建Graph

initialize_training_folder(hypes)
maybe_download_and_extract(hypes)

开始初始化训练文件夹,训练的时候的RUN文件夹,包括image,model_file,output.log等。进一步保证hypes存在。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值