**
- Introduction to KittiClass in MultiNet
**
(a shared encoder that processes the image to extract features, plus decoders that carry out the given tasks)
MultiNet is an efficient feed-forward architecture for jointly reasoning about semantic segmentation, image classification and object detection. The three tasks share a common encoder, and the network has three branches, each a decoder that carries out one of the tasks.
MultiNet can be trained end-to-end, and joint inference over all three tasks completes in under 100 ms. We begin by introducing the joint encoder and then describe the classification decoder.
The encoder's job is to process the image and extract rich, abstract features [49] that contain all the information needed for accurate segmentation, detection and image classification. The MultiNet encoder consists of the first 13 layers of the VGG16 network [45] (VGGNet is widely used for image feature extraction) and, applied in a fully convolutional manner, produces a tensor of size 39×12×512; this is the output of the fifth pooling layer, called pool5 in VGG. The classification decoder is designed to take advantage of this encoder output: a 1×1 convolution produces a 39×12×300 hidden layer, followed by a fully connected layer and a softmax layer that outputs the final class probabilities.
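To make the classification decoder concrete, the following is a minimal TensorFlow 1.x sketch of such a head. The hidden width of 300, the two road classes and all layer names are assumptions based on the description above and the demo below, not the exact KittiClass implementation.

import tensorflow as tf

def classification_head(pool5, num_classes=2, hidden_channels=300):
    """Sketch of a classification decoder on top of the shared encoder.

    pool5: encoder output of the fully convolutional VGG16,
           e.g. a [batch, 12, 39, 512] feature map.
    """
    # 1x1 convolution producing the hidden layer described above
    hidden = tf.layers.conv2d(pool5, filters=hidden_channels, kernel_size=1,
                              activation=tf.nn.relu, name='hidden_1x1')
    # fully connected layer mapping the flattened features to class scores
    flat = tf.layers.flatten(hidden)
    logits = tf.layers.dense(flat, num_classes, name='fc')
    # softmax gives the final class probabilities
    return tf.nn.softmax(logits, name='class_probabilities')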
**
- Running and studying the demo
**
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Marvin Teichmann
"""
Classify an image using KittiClass.
Input: Image
Output: Image (with the classification result drawn on it)
Utilizes: Trained KittiClass weights. If no logdir is given,
pretrained weights will be downloaded and used.
Usage:
python demo.py --input data/demo.png [--output output]
[--logdir /path/to/weights] [--gpus 0]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import sys
import collections
# configure logging: record key progress information
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)
# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf
flags = tf.app.flags  # tf.app.flags handles parsing of command-line arguments
FLAGS = flags.FLAGS   # FLAGS gives access to the parsed flag values
sys.path.insert(1, 'incl')  # put 'incl' near the front of the import search path
# so that the bundled submodules are found before other directories;
# my project has no incl directory, so this has no effect (see the logdir flag below)
try:
    # Check whether setup was done correctly
    import tensorvision.utils as tv_utils
    import tensorvision.core as core
except ImportError:
    # You forgot to initialize submodules
    logging.error("Could not import the submodules.")
    logging.error("Please execute:"
                  "'git submodule update --init --recursive'")
    exit(1)
# In the original script the default is None; since I do not have the incl
# setup, I point logdir at my weights directory explicitly.
flags.DEFINE_string('logdir', 'logdir',
                    'Path to logdir.')
# tf.app.flags.DEFINE_string() declares a flag that receives a string value;
# the three arguments are: flag name, default value, help text.
# cjh-2019-4-16
flags.DEFINE_string('input', 'DATA/demo/007007.png',
                    'Image to apply KittiClass.')
flags.DEFINE_string('output', 'DATA/demo/007007_.png',
                    'Path to save the output image.')

default_run = 'KittiClass_postpaper'  # name of the folder holding the weights
# URL of the pretrained network weights
weights_url = ("ftp://mi.eng.cam.ac.uk/"
               "pub/mttt2/models/KittiClass_postpaper.zip")
# i.e. ftp://mi.eng.cam.ac.uk/pub/mttt2/models/ ; the archive can also be
# downloaded manually
from PIL import Image, ImageDraw, ImageFont
# Draw the classification result onto the image
def road_draw(image, highway):
    im = Image.fromarray(image.astype('uint8'))
    draw = ImageDraw.Draw(im)
    # fnt = ImageFont.truetype('FreeMono/FreeMonoBold.ttf', 40)
    fnt = ImageFont.truetype('simhei.ttf', 40)
    shape = image.shape
    if highway:
        draw.text((65, 10), "Highway",
                  font=fnt, fill=(255, 255, 0, 255))
        draw.ellipse([10, 10, 55, 55], fill=(255, 255, 0, 255),
                     outline=(255, 255, 0, 255))
    else:
        draw.text((65, 10), "small road",
                  font=fnt, fill=(255, 0, 0, 255))
        draw.ellipse([10, 10, 55, 55], fill=(255, 0, 0, 255),
                     outline=(255, 0, 0, 255))
    return np.array(im).astype('float32')
# runs_dir is the base directory; this function downloads the pretrained
# weights and extracts the zip archive, doing nothing if they already exist.
def maybe_download_and_extract(runs_dir):
    logdir = os.path.join(runs_dir, default_run)  # build the target path
    if os.path.exists(logdir):
        # weights are downloaded. Nothing to do
        return

    # download and extract the archive
    import zipfile
    download_name = tv_utils.download(weights_url, runs_dir)
    logging.info("Extracting KittiClass_postpaper.zip")
    zipfile.ZipFile(download_name, 'r').extractall(runs_dir)
    return
# Resize the input image and the ground-truth label image to the given size
def resize_label_image(image, gt_image, image_height, image_width):
    image = scp.misc.imresize(image, size=(image_height, image_width),
                              interp='cubic')
    shape = gt_image.shape
    gt_image = scp.misc.imresize(gt_image, size=(image_height, image_width),
                                 interp='nearest')
    return image, gt_image
def main(_):
    # choose which GPU(s) to run on
    tv_utils.set_gpus_to_use()

    if FLAGS.input is None:
        logging.error("No input was given.")
        logging.info(
            "Usage: python demo.py --input data/test.png "
            "[--output output] [--logdir /path/to/weights] "
            "[--gpus GPUs_to_use] ")
        exit(1)

    if FLAGS.logdir is None:
        # Download and use weights from the MultiNet Paper
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(os.environ['TV_DIR_RUNS'],
                                    'KittiClass')
        else:
            runs_dir = 'RUNS'
        maybe_download_and_extract(runs_dir)
        logdir = os.path.join(runs_dir, default_run)
    else:
        logging.info("Using weights found in {}".format(FLAGS.logdir))
        logdir = FLAGS.logdir
    # The block above only works out where the weights are stored and
    # assigns that path to logdir.

    # Loading hyperparameters from logdir (the directory the weights were
    # extracted to)
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')
    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    with tf.Graph().as_default():
        # Create placeholder for input
        image_pl = tf.placeholder(tf.float32)
        image = tf.expand_dims(image_pl, 0)

        # build Tensorflow graph using the model from logdir
        prediction = core.build_inference_graph(hypes, modules,
                                                image=image)
        logging.info("Graph build successfully.")

        # Create a session for running Ops on the Graph.
        sess = tf.Session()
        saver = tf.train.Saver()  # a Saver object is needed to restore the weights

        # Load weights from logdir
        # cjh-2019-4-16
        # logdir = r'/home/xue/MultiNet-master/submodules/KittiClass/KittiClass_postpaper'
        logdir = r'E:/lixueqian/2019/new_method/MultiNet-master1/submodules/KittiClass/KittiClass_postpaper'
        core.load_weights(logdir, sess, saver)
        logging.info("Weights loaded successfully.")

    input = FLAGS.input
    logging.info("Starting inference using {} as input".format(input))
    # encoder.inference (called while building the graph) seems to set up the
    # full VGG16 feature extractor

    # Load and resize input image
    # input = r'/home/xue/MultiNet-master/submodules/KittiClass/KittiClass-master/DATA/demo/007034.png'
    image = scp.misc.imread(input)
    if hypes['jitter']['reseize_image']:
        # Resize input only, if specified in hypes
        image_height = hypes['jitter']['image_height']
        image_width = hypes['jitter']['image_width']
        image = scp.misc.imresize(image, size=(image_height, image_width),
                                  interp='cubic')

    # Run KittiClass model on image
    feed = {image_pl: image}  # feed the placeholder with the loaded image
    softmax_road, _ = prediction['softmax']
    output = sess.run([softmax_road], feed_dict=feed)
    # output holds the predicted class probabilities

    # Get predicted class
    highway = (np.argmax(output[0][0]) == 0)
    # highway is True when the most probable class is "Highway"

    # Draw resulting output image
    new_img = road_draw(image, highway)

    # Save output images to disk.
    if FLAGS.output is None:
        output_base_name = input
        out_image_name = output_base_name.split('.')[0] + '_out.png'
    else:
        out_image_name = FLAGS.output

    scp.misc.imsave(out_image_name, new_img)

    logging.info("")
    logging.info("Output image has been saved to: {}".format(
        os.path.realpath(out_image_name)))


if __name__ == '__main__':
    tf.app.run()  # tf.app.run() parses the flags again and then calls main()
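With the flag defaults above, a typical invocation looks like this (the paths follow my setup and are only an example):

python demo.py --input DATA/demo/007007.png --output DATA/demo/007007_.png --logdir logdir --gpus 0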
Result: the demo draws the predicted class ("Highway" or "small road") in the top-left corner of the image and saves it to the output path.
**
- Studying the train.py source code
**
(parts that duplicate the demo are not repeated here)
- Predefined flags:
These include name, project, hypes, mod, and a boolean flag model; hypes points to the file that stores the model parameters, hypes/KittiClass.json (see the sketch below).
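A rough sketch of how these flags could be declared with tf.app.flags; the default values and help strings here are my assumptions, not the exact definitions in train.py.

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('name', None, 'Append a name tag to the run directory.')
flags.DEFINE_string('project', None, 'Append a project name to the run directory.')
flags.DEFINE_string('hypes', 'hypes/KittiClass.json',
                    'JSON file storing the model parameters.')
flags.DEFINE_string('mod', None,
                    'String holding a dict of hyperparameter overrides.')
flags.DEFINE_boolean('model', False,
                     'Boolean switch (named "model" in the notes above).')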
- Defining the dict_merge function
def dict_merge(dct, merge_dct):
Purpose: recursive dictionary merge. Instead of updating only the top-level keys, dict_merge recurses into dicts nested to an arbitrary depth and updates keys at every level; merge_dct is merged into dct.
Notes: isinstance(object, classinfo) returns True if object is an instance of classinfo or of one of its subclasses.
def dict_merge(dct, merge_dct):
    """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct``.
    :param dct: dict onto which the merge is executed
    :param merge_dct: dct merged into dct
    :return: None
    """
    # Python 2 style iteration; under Python 3 this would be merge_dct.items()
    for k, v in merge_dct.iteritems():
        if (k in dct and isinstance(dct[k], dict) and
                isinstance(merge_dct[k], collections.Mapping)):
            dict_merge(dct[k], merge_dct[k])
        else:
            dct[k] = merge_dct[k]
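A small usage example with made-up values, showing that nested keys are updated in place while sibling keys are preserved:

hypes = {'solver': {'learning_rate': 1e-5, 'max_steps': 10000},
         'arch': {'num_classes': 2}}
override = {'solver': {'learning_rate': 1e-4}}

dict_merge(hypes, override)
# hypes['solver'] is now {'learning_rate': 0.0001, 'max_steps': 10000},
# while hypes['arch'] is left untouched.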
- The main function
It sets the GPU, imports tensorvision.train and tensorflow_fcn.utils, opens the file given by tf.app.flags.FLAGS.hypes and reads it directly with json.load (the hypes JSON loads successfully), and calls utils.load_plugins(), whose job is to "Load all TensorVision plugins".
if tf.app.flags.FLAGS.mod is not None:
    import ast
    mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)  # parse the --mod string into a dict
    dict_merge(hypes, mod_dict)
The two calls below are equivalent here: eval() executes a string expression and returns its value, turning the string into a list or dict, while ast.literal_eval() does the same but only accepts Python literals, which makes it the safer choice.
import ast
expr = "[1, 2, 3]"
my_list = ast.literal_eval(expr)
my_list = eval(expr)
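Putting these pieces together, a minimal sketch of the hypes-loading step described above; it reuses dict_merge from earlier and the flag names from the notes, so it is not a verbatim copy of train.py.

import ast
import json
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

with open(FLAGS.hypes, 'r') as f:
    hypes = json.load(f)                    # e.g. hypes/KittiClass.json

if FLAGS.mod is not None:
    # e.g. --mod "{'solver': {'max_steps': 100}}"
    mod_dict = ast.literal_eval(FLAGS.mod)  # safely parse the string into a dict
    dict_merge(hypes, mod_dict)             # merge the overrides into hypes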
utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)
utils._add_paths_to_sys(hypes)
logging.info("Initialize training folder")
train.initialize_training_folder(hypes)
# Adds files under RUNS/kittiBox../model_files and creates an images folder
# and an output log file, i.e. it produces the layout below:
'''
RUNS/KittiBox_date_/images
RUNS/KittiBox_date_/model_files
    architecture.py
    data_input.py
    eval.py
    hypes.py
    objective.py
    solver.py
RUNS/KittiBox_date_/output.log
'''
train.maybe_download_and_extract(hypes) downloads and unpacks the zipped data, doing nothing if it already exists, and then training starts; here hypes is the parsed hypes.json.
Execution then moves to do_training in tensorvision's train.py, which contains the main steps of training the model.
- do_training
def do_training(hypes):
    """
    Train model for a number of steps.

    This trains the model for at most hypes['solver']['max_steps'].
    It shows an update every utils.cfg.step_show steps and writes
    the model to hypes['dirs']['output_dir'] every utils.cfg.step_eval
    steps.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    # Get the sets of images and labels for training, validation, and
    # test on MNIST.

    # load the training-related .py files
    modules = utils.load_modules_from_hypes(hypes)
    '''
    ['input']     : ../inputs/Kitti_input.py
    ['arch']      : ../encoder/vgg.py
    ['objective'] : ../decoder/fastBox.py
    ['solver']    : ../optimizer/generic_optimizer.py
    ['eval']      : ../evals/kitti_eval.py
    '''

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Session() as sess:

        # build the graph based on the loaded modules
        with tf.name_scope("Queues"):
            queue = modules['input'].create_queues(hypes, 'train')

        tv_graph = core.build_training_graph(hypes, queue, modules)

        # prepare the tv session
        tv_sess = core.start_tv_session(hypes)

        with tf.name_scope('Validation'):
            tf.get_variable_scope().reuse_variables()
            image_pl = tf.placeholder(tf.float32)
            image = tf.expand_dims(image_pl, 0)
            image.set_shape([1, None, None, 3])
            inf_out = core.build_inference_graph(hypes, modules,
                                                 image=image)
            tv_graph['image_pl'] = image_pl
            tv_graph['inf_out'] = inf_out

        # Start the data load
        modules['input'].start_enqueuing_threads(hypes, queue, 'train', sess)

        # And then after everything is built, start the training loop.
        run_training(hypes, modules, tv_graph, tv_sess)

        # stopping input Threads
        tv_sess['coord'].request_stop()
        tv_sess['coord'].join(tv_sess['threads'])
modules = utils.load_modules_from_hypes(hypes) loads the modules specified by the hypes file, i.e. input_file, architecture_file, objective_file and optimizer_file (sketched below).
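tensorvision's actual implementation is not reproduced here; the following is only a sketch of the idea (dynamically importing the .py files named in the hypes), assuming an imp-based loader and a hypes['model'] section that holds the file names.

import imp
import os

def load_modules_sketch(hypes, base_path=''):
    """Sketch only: import the .py files listed in the hypes and return them
    under the keys used by do_training above."""
    model = hypes['model']  # assumed sub-dict holding the *_file entries

    def _load(name, rel_path):
        return imp.load_source(name, os.path.join(base_path, rel_path))

    return {
        'input':     _load('input',     model['input_file']),
        'arch':      _load('arch',      model['architecture_file']),
        'objective': _load('objective', model['objective_file']),
        'solver':    _load('solver',    model['optimizer_file']),
    }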
with tf.name_scope("Queues"):
    queue = modules['input'].create_queues(hypes, 'train')
This creates the input queues, from which the training graph is then built using the loaded modules.
initialize_training_folder(hypes)
maybe_download_and_extract(hypes)
These initialize the training folder used during the run (the RUNS folder with images, model_files, output.log, etc., as sketched below) and make sure the required data exists.
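For intuition, a rough sketch of what this initialization amounts to; the directory names come from the listing earlier and hypes['dirs']['output_dir'] from the do_training docstring, while everything else is an assumption (the real tensorvision code may differ).

import os
import shutil

def initialize_training_folder_sketch(hypes):
    """Sketch only: lay out the RUNS/<run_name>/ directory described above."""
    output_dir = hypes['dirs']['output_dir']   # e.g. RUNS/KittiClass_<date>
    model_dir = os.path.join(output_dir, 'model_files')
    for d in (output_dir, os.path.join(output_dir, 'images'), model_dir):
        if not os.path.exists(d):
            os.makedirs(d)
    # copy the model .py files into the run so it is self-contained
    for key in ('input_file', 'architecture_file',
                'objective_file', 'optimizer_file'):
        shutil.copy(hypes['model'][key], model_dir)
    # output.log collects the run's logging output
    open(os.path.join(output_dir, 'output.log'), 'a').close()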