Reference: https://blog.youkuaiyun.com/qq_29462849/article/details/80510687
It is best not to follow that post, though.
Pretrained model download: http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_11_06_2017.tar.gz
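The download is a gzipped tarball. Below is a minimal sketch of pulling the frozen graph out of it after downloading, assuming the archive contains a member whose base name is `frozen_inference_graph.pb` (the file name the script in this post loads). The helper name `extract_frozen_graph` and the destination directory are hypothetical.

```python
import os
import tarfile


def extract_frozen_graph(tar_path, dest_dir="."):
    """Copy frozen_inference_graph.pb out of a model tarball into dest_dir.

    Returns the path of the written file. Assumes (not verified here) that
    the archive holds a member whose base name is frozen_inference_graph.pb.
    """
    target = "frozen_inference_graph.pb"
    with tarfile.open(tar_path, "r:gz") as archive:
        for member in archive.getmembers():
            if member.isfile() and os.path.basename(member.name) == target:
                out_path = os.path.join(dest_dir, target)
                # Read the member directly instead of calling extractall(),
                # so no archive-controlled path is written to disk.
                with archive.extractfile(member) as src, open(out_path, "wb") as dst:
                    dst.write(src.read())
                return out_path
    raise FileNotFoundError(target + " not found in " + tar_path)
```

For example, `extract_frozen_graph("faster_rcnn_resnet101_coco_11_06_2017.tar.gz")` would place the `.pb` file in the current directory, if the archive has the assumed layout.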
Model list:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
Look for the related modules and files in this repository:
https://github.com/tensorflow/models
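The following are needed, judging by what the script loads and imports: mscoco_label_map.pbtxt, frozen_inference_graph.pb, label_map_util.py, visualization_utils.py, and a picture folder holding the input images.
Code: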
# -*- coding: utf-8 -*-
# Written against the TensorFlow 1.x API (tf.gfile, tf.GraphDef, tf.Session).
import os

import cv2
import numpy as np
import tensorflow as tf

import label_map_util
import visualization_utils as vis_util

# Map COCO class ids to display names.
path = r'./mscoco_label_map.pbtxt'
label_map = label_map_util.load_labelmap(path)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=90, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the frozen graph from disk.
with tf.gfile.FastGFile(r'frozen_inference_graph.pb', mode='rb') as f:
    graph_def = tf.GraphDef()            # empty graph definition object
    graph_def.ParseFromString(f.read())  # parse the serialized bytes into it

with tf.Session() as sess:
    graph = sess.graph
    tf.import_graph_def(graph_def, name='')  # import into the session's graph

    # Look up the input and output tensors once, outside the loop.
    # A frozen graph stores its weights as constants, so no variable
    # initializer needs to be run.
    image_tensor = graph.get_tensor_by_name('image_tensor:0')
    fetches = [
        graph.get_tensor_by_name('detection_boxes:0'),
        graph.get_tensor_by_name('detection_scores:0'),
        graph.get_tensor_by_name('detection_classes:0'),
        graph.get_tensor_by_name('num_detections:0'),
    ]

    for name in os.listdir('picture'):
        image_path = os.path.join(os.getcwd(), 'picture', name)
        # np.fromfile + cv2.imdecode works around garbled Chinese paths.
        # IMREAD_COLOR forces 3 channels, which the input tensor requires.
        image_bgr = cv2.imdecode(
            np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image_bgr is None:  # not a decodable image, skip it
            continue
        # OpenCV decodes to BGR; feed the model RGB with a batch dimension.
        batch = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)[None]

        boxes, scores, classes, num_detections = sess.run(
            fetches, feed_dict={image_tensor: batch})

        # Draw the detected boxes and labels onto the image in place.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_bgr,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True)
        cv2.imwrite('ex' + name, image_bgr)
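The arrays returned by `sess.run` carry a leading batch dimension, and `detection_boxes` holds normalized `[ymin, xmin, ymax, xmax]` rows. To use the detections outside the drawing helper, a minimal NumPy sketch like the one below keeps the confident ones and converts them to pixel coordinates. The function name `filter_detections` and the 0.5 threshold are assumptions chosen for illustration.

```python
import numpy as np


def filter_detections(boxes, scores, classes, image_shape, min_score=0.5):
    """Return (pixel_boxes, scores, classes) for detections above min_score.

    boxes:       shape (1, N, 4), normalized [ymin, xmin, ymax, xmax]
    scores:      shape (1, N)
    classes:     shape (1, N)
    image_shape: (height, width, ...) of the image that was fed in
    """
    boxes = np.squeeze(boxes, axis=0)
    scores = np.squeeze(scores, axis=0)
    classes = np.squeeze(classes, axis=0).astype(np.int32)

    keep = scores >= min_score
    height, width = image_shape[:2]
    # y coordinates scale with the height, x coordinates with the width.
    scale = np.array([height, width, height, width], dtype=np.float32)
    pixel_boxes = np.round(boxes[keep] * scale).astype(np.int32)
    return pixel_boxes, scores[keep], classes[keep]
```

With the names from the script above, the call would look like `filter_detections(boxes, scores, classes, data1.shape)`.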
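Result
Final note:
Prediction accuracy still has problems. During testing I found that, for the last image, reading it with different methods also gave different displayed results.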
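One possible contributor is channel order: OpenCV decodes images as BGR, while PIL and matplotlib produce RGB, so the same file can reach the model (or the display) with red and blue swapped depending on how it was read. Whether that explains the difference seen here is an assumption, not something verified in this post. A small NumPy sketch of the conversion, with a hypothetical helper name:

```python
import numpy as np


def swap_red_blue(image):
    """Reverse the channel axis of an H x W x 3 image (BGR <-> RGB)."""
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError("expected an H x W x 3 array")
    # ascontiguousarray makes a real copy; a plain [..., ::-1] view has
    # negative strides, which some image functions do not accept.
    return np.ascontiguousarray(image[..., ::-1])


# A single pure-blue pixel as OpenCV would store it (B, G, R).
bgr_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)
rgb_pixel = swap_red_blue(bgr_pixel)
print(rgb_pixel[0, 0].tolist())  # [0, 0, 255]
```

Applying the swap twice returns the original array, so the same helper converts in either direction.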