Tensorflow Object Detection API
The pig detection code, along with the follow-up pig classification program, has been open-sourced on GitHub.
The main changes made on top of the official demo code are:
- Expand each detected box so it encloses the target more fully, and clamp the crop so it never goes outside the image boundary [expand_ratio] (a sketch follows this list).
- Detections labeled pig or animal may both be correct; the acceptance rule can be tuned from the observed results to suppress high-confidence but irrelevant classes and improve robustness [class_keep].
- Feed images in mini batches to make full use of the GPU and speed up the program (a batching sketch appears after the inference code further down).
Let's focus on the core code related to the object detection API:
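As a rough illustration of the first two points, here is a minimal sketch of box expansion with boundary clamping plus a class_keep-style filter. The helper names, the normalized [ymin, xmin, ymax, xmax] box format (what the API returns), and the thresholds are assumptions for illustration, not the actual open-sourced code.

import numpy as np

def expand_and_clamp(box, expand_ratio, img_h, img_w):
    """Expand a normalized [ymin, xmin, ymax, xmax] box by expand_ratio and
    clamp it to the image, returning pixel coordinates for cropping.
    (Hypothetical helper; the real repository code may differ.)"""
    ymin, xmin, ymax, xmax = box
    h, w = ymax - ymin, xmax - xmin
    ymin = max(0.0, ymin - expand_ratio * h)
    ymax = min(1.0, ymax + expand_ratio * h)
    xmin = max(0.0, xmin - expand_ratio * w)
    xmax = min(1.0, xmax + expand_ratio * w)
    return (int(ymin * img_h), int(xmin * img_w),
            int(ymax * img_h), int(xmax * img_w))

def filter_detections(boxes, scores, classes, class_keep, min_score=0.5):
    """Keep only detections whose class id is in class_keep and whose score
    passes the threshold, suppressing confident but irrelevant classes."""
    keep = np.array([(s >= min_score) and (c in class_keep)
                     for s, c in zip(scores, classes)], dtype=bool)
    return boxes[keep], scores[keep], classes[keep]

Applied per image after sess.run, this keeps the crops inside the frame while still giving the downstream classifier some context around each pig.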
# Load a (frozen) Tensorflow model into memory
'''
tf.GraphDef():
The GraphDef class is an object created by the ProtoBuf library.
See https://www.tensorflow.org/extend/tool_developers/ for details.
graph_def:
A GraphDef proto containing operations to be imported into the default graph.
'''
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
'''
A few util functions are used here to load the label map and build the category index.
'''
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
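For reference, create_category_index returns a plain dict keyed by class id, so turning a predicted class id into a display name is just a lookup. The concrete id/name pair below is an assumption, not taken from the pig label map:

# category_index maps integer class ids to dicts such as {'id': 1, 'name': 'pig'}
class_id = 1                               # e.g. a class id returned by the detector
print(category_index[class_id]['name'])    # -> 'pig' if id 1 is pig in the label map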
'''
Now the key part: how the computation graph is defined.
In this script, images are passed to the graph via feed_dict={image_tensor: image_np_expanded}. A previous post covered how to use self-generated tfrecord files; the Dataset API introduced in TF 1.4 is another option.
get_tensor_by_name simply retrieves a tensor by its name; see the small test snippet further below.
Still, it is not obvious from this code alone why the graph works: it looks like we merely fetch a few tensors, so presumably the detection boxes and the other outputs depend on image_tensor. Let's confirm this in the source code. In object_detection/inference/detection_inference.py, the function build_inference_graph loads the inference graph and connects it to the input image.
Specifically:
tf.import_graph_def(
    graph_def, name='', input_map={'image_tensor': image_tensor})
Official docs: input_map: A dictionary mapping input names (as strings) in graph_def to Tensor objects. The values of the named input tensors in the imported graph will be re-mapped to the respective Tensor values.
Next, where is build_inference_graph called? It is indeed called under the inference folder, but the feed-based path used here does not go through that function. The guess is that the name image_tensor was fixed when the network was exported, and indeed object_detection/exporter.py shows that image_tensor is a placeholder, as expected. The actual wiring of the graph is the model definition itself, which we'll look at next time when analyzing the training code.
'''
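To make the input_map idea concrete, here is a minimal sketch (separate from the demo's feed_dict path) that wires the frozen graph's image_tensor directly to an image produced by a small decode pipeline. The file name and the pipeline itself are assumptions for illustration:

# Sketch: remap 'image_tensor' onto our own tensor so no feed_dict is needed.
graph = tf.Graph()
with graph.as_default():
    image_raw = tf.read_file('example.jpg')              # hypothetical input image
    image = tf.image.decode_jpeg(image_raw, channels=3)  # uint8 HxWx3
    image_batch = tf.expand_dims(image, axis=0)          # add the batch dimension

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        graph_def.ParseFromString(fid.read())
    # 'image_tensor' inside the frozen graph now reads from image_batch directly.
    tf.import_graph_def(graph_def, name='',
                        input_map={'image_tensor': image_batch})

    boxes = graph.get_tensor_by_name('detection_boxes:0')
    with tf.Session(graph=graph) as sess:
        print(sess.run(boxes))   # no feed_dict: the input comes from the pipeline

This is roughly what build_inference_graph does, except that there the input tensor is decoded from TFRecord examples rather than a single JPEG.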
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        # Define input and output Tensors for detection_graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represents the level of confidence for each of the objects.
        # The score is shown on the result image, together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        '''Note: some code is omitted here'''
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
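To illustrate the mini-batch point from the list at the top: the exported image_tensor placeholder is defined with a batch dimension (its shape is [None, None, None, 3] in object_detection/exporter.py), so several images of the same size can be stacked and scored in a single sess.run call. A minimal sketch, assuming image_list already holds equally sized uint8 arrays:

import numpy as np

# Stack N equally sized HxWx3 images into one [N, H, W, 3] batch.
batch = np.stack(image_list, axis=0)
(boxes, scores, classes, num) = sess.run(
    [detection_boxes, detection_scores, detection_classes, num_detections],
    feed_dict={image_tensor: batch})
# boxes has shape [N, max_detections, 4]; row i belongs to image_list[i].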
Here is the small get_tensor_by_name test mentioned above:
import tensorflow as tf
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')
with tf.Session() as sess:
    test = sess.run(e)
    print(e.name)   # example:0
    print(test)
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print(test)     # Tensor("example:0", shape=(2, 2), dtype=float32)
    print(test.eval())
'''
The output is:
example_2:0
[[ 1.  3.]
 [ 3.  7.]]
Tensor("example:0", shape=(2, 2), dtype=float32)
[[ 1.  3.]
 [ 3.  7.]]
(The name prints as example_2:0 rather than example:0 here most likely because the snippet had been run a few times in the same default graph, so TensorFlow appended a suffix to keep the op name unique.)
'''