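I downloaded the TensorFlow/models code from GitHub and, after a day of tinkering, finally got the program running. The code is below; let's study it together.
TensorFlow/models download address: https://github.com/tensorflow/models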
#coding:utf-8
import numpy as np
import tensorflow as tf  # written against the TensorFlow 1.x API (tf.GraphDef, tf.gfile, tf.Session)
import matplotlib.pyplot as plt
from PIL import Image
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
# Two files are needed: the frozen model (.pb) and its label map (.pbtxt).
PATH_TO_CKPT = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
PATH_TO_LABELS = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/mscoco_label_map.pbtxt'
# Total number of detection classes
NUM_CLASSES = 90
# Create a computation graph and load the frozen model into it
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
TEST_IMAGE_PATHS = ['object_detection/test_images/image5.jpg', 'object_detection/test_images/image4.jpg']
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        for image_path in TEST_IMAGE_PATHS:
            image = Image.open(image_path)
            image_np = load_image_into_numpy_array(image)
            image_np_expanded = np.expand_dims(image_np, axis=0)
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
                category_index, use_normalized_coordinates=True, line_thickness=8)
            plt.figure(figsize=[12, 8])
            plt.imshow(image_np)
            plt.show()
All commands in this article are run from the research directory. Some paths in the downloaded source may not match your setup; adjust them as needed.
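Because the relative paths only resolve when the script is launched from research, a quick pre-flight check can save some confusion. The helper below, missing_files, is hypothetical (it is not part of the Object Detection API); it simply reports which of the given paths do not exist relative to the current working directory.

```python
import os


def missing_files(paths):
    """Return the paths (relative to the current directory) that are not existing files."""
    return [p for p in paths if not os.path.isfile(p)]


if __name__ == "__main__":
    required = [
        "object_detection/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb",
        "object_detection/ssd_mobilenet_v1_coco_2017_11_17/mscoco_label_map.pbtxt",
    ]
    missing = missing_files(required)
    if missing:
        print("Run from the research directory, or fix these paths:")
        for p in missing:
            print("  ", p)
    else:
        print("All model files found.")
```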
# Two files are needed: the frozen model (.pb) and its label map (.pbtxt).
PATH_TO_CKPT = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
PATH_TO_LABELS = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/mscoco_label_map.pbtxt'
# Total number of detection classes
NUM_CLASSES = 90
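PATH_TO_CKPT is the path of the model to load, here a .pb file.
PATH_TO_LABELS is the path of the label map that belongs to this model.
NUM_CLASSES is the total number of detection classes, i.e. the largest class id the model can output. For this COCO-trained model the ids run up to 90 (COCO has 80 categories, numbered with gaps), so it is set to 90.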
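Why 90 rather than 80? As I read the label_map_util source, categories are filtered by their id, not by how many there are, so NUM_CLASSES has to cover the largest id. The toy sketch below uses a small made-up list of ids (not the real COCO map) and hypothetical helper names to illustrate the point.

```python
# Made-up subset of label-map ids; the real COCO map has 80 entries with ids up to 90.
label_ids = [1, 2, 3, 88, 90]


def required_num_classes(ids):
    """NUM_CLASSES must be at least the largest id so that no entry is filtered out."""
    return max(ids)


def kept_ids(ids, num_classes):
    """Ids that survive a filter of the form 1 <= id <= num_classes."""
    return [i for i in ids if 1 <= i <= num_classes]


print(len(label_ids))                          # 5 entries...
print(required_num_classes(label_ids))         # ...but NUM_CLASSES must be 90
print(kept_ids(label_ids, len(label_ids)))     # [1, 2, 3]: ids 88 and 90 would be lost
```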
# Create a computation graph
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')
For an introduction to computation graphs, see: https://blog.youkuaiyun.com/weixin_41874599/article/details/82663676
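detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
These lines create a computation graph, make it the default graph inside the with block, and create an empty GraphDef protocol buffer that will hold the serialized model.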
==========================================================================================
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    od_graph_def.ParseFromString(fid.read())
    tf.import_graph_def(od_graph_def, name='')
① tf.gfile.GFile(filename, mode)
Returns a file handle, much like Python's built-in open(): filename is the file to open and mode is the read/write mode. tf.gfile.Open() is an alias for this interface, so either one can be used.
Reference: https://blog.youkuaiyun.com/pursuit_zhangyu/article/details/80557958
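For a local file, reading through tf.gfile.GFile(path, 'rb') gives the same bytes as Python's built-in open(path, 'rb'); the advantage of GFile is that the same call also works on the other filesystems TensorFlow supports, such as gs:// paths. The stdlib-only sketch below mirrors the read step with plain open(); read_model_bytes is a hypothetical name, and a tiny stand-in file replaces the real .pb model.

```python
import os
import tempfile


def read_model_bytes(path):
    """Read a binary file in one go, like fid.read() on a GFile opened with 'rb'."""
    with open(path, "rb") as fid:
        return fid.read()


if __name__ == "__main__":
    # Stand-in file so the example runs without a real frozen graph.
    with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as tmp:
        tmp.write(b"\x0a\x03abc")
    data = read_model_bytes(tmp.name)
    print(type(data).__name__, len(data))  # bytes 5
    os.remove(tmp.name)
```

In the real script, these raw bytes are what ParseFromString() consumes.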
② tf.GraphDef().ParseFromString()
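Parses a serialized message into the current message. Here the raw bytes read from the .pb file are parsed into od_graph_def, and tf.import_graph_def() then imports that GraphDef into the current default graph; name='' keeps the original node names without adding a prefix.
Reference: the TensorFlow help documentation, shown in the figure below:
In other words, this part restores the structure of the computation graph, together with its trained weights, from the frozen model file frozen_inference_graph.pb. Despite the variable name PATH_TO_CKPT, this is a frozen .pb graph rather than a .ckpt checkpoint.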
================================================================================================================================================================================================
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
①label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
Parses the .pbtxt file at PATH_TO_LABELS and returns the label map.
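The label map is a protobuf text file made of item { ... } blocks, each with a name, an integer id and, in the COCO file, a human-readable display_name; load_labelmap parses it into a StringIntLabelMap protobuf message. The regex-based parser below is only a hypothetical stand-in to make the file's structure visible: parse_label_map and the two-item sample are illustrative, and it handles just this simple flat layout, not the full protobuf text format.

```python
import re

# Sample in the same layout as mscoco_label_map.pbtxt (only two items shown).
SAMPLE_PBTXT = '''
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
'''

ITEM_RE = re.compile(r"item\s*\{(.*?)\}", re.S)            # body of each item { ... }
FIELD_RE = re.compile(r'(\w+)\s*:\s*("([^"]*)"|-?\d+)')     # key: "string" or key: int


def parse_label_map(text):
    """Turn flat item { key: value } blocks into a list of dicts."""
    items = []
    for body in ITEM_RE.findall(text):
        item = {}
        for key, raw, quoted in FIELD_RE.findall(body):
            item[key] = quoted if raw.startswith('"') else int(raw)
        items.append(item)
    return items


print(parse_label_map(SAMPLE_PBTXT))
```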
②categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
The docstring of label_map_util.convert_label_map_to_categories in the source reads: This function converts label map proto and returns a list of dicts, each of which has the following keys:
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name, e.g., 'cat', 'dog', 'pizza'.
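In short: it converts the label map proto into a list of dicts, each with the keys 'id' and 'name'.
Note: each id maps to a single category; if several items in the label map share the same id, only the first one is kept. Items whose id falls outside the range allowed by max_num_classes are dropped as well.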
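A simplified, pure-Python approximation of the behaviour described above, working on plain dicts instead of the protobuf message. to_categories and the sample items are hypothetical; the 1-based id-range check reflects my reading of the source, so treat this as a sketch rather than the library's code.

```python
def to_categories(items, max_num_classes, use_display_name=True):
    """Approximate convert_label_map_to_categories on a list of dicts."""
    categories = []
    seen_ids = set()
    for item in items:
        # Only ids in 1..max_num_classes are kept (ids are 1-based).
        if not 1 <= item["id"] <= max_num_classes:
            continue
        # If several items share an id, keep only the first one.
        if item["id"] in seen_ids:
            continue
        seen_ids.add(item["id"])
        if use_display_name and "display_name" in item:
            name = item["display_name"]
        else:
            name = item["name"]
        categories.append({"id": item["id"], "name": name})
    return categories


items = [
    {"name": "/m/01g317", "id": 1, "display_name": "person"},
    {"name": "duplicate", "id": 1, "display_name": "also person"},
    {"name": "/m/0199g", "id": 2, "display_name": "bicycle"},
    {"name": "too_big", "id": 95, "display_name": "out of range"},
]
print(to_categories(items, max_num_classes=90))
# [{'id': 1, 'name': 'person'}, {'id': 2, 'name': 'bicycle'}]
```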
③category_index = label_map_util.create_category_index(categories)
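label_map_util.create_category_index(categories)
categories is a list whose elements are dicts with the keys 'id' and 'name'.
The function returns a dict whose keys are the integer category ids and whose values are the corresponding dicts from that list. See the source code for details.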
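The same idea as a one-line dict comprehension, under a hypothetical name so it is not confused with the library function. The last lines show why this index is useful: the detector returns class ids as floats, and after casting to int they can be used to look up a category name.

```python
def make_category_index(categories):
    """Index a list of {'id', 'name'} dicts by their integer id."""
    return {cat["id"]: cat for cat in categories}


categories = [{"id": 1, "name": "person"}, {"id": 3, "name": "car"}]
index = make_category_index(categories)
print(index[3])  # {'id': 3, 'name': 'car'}

detected_class = 3.0  # class ids come back from the model as floats
print(index[int(detected_class)]["name"])  # car
```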
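import cv2

def load_image_into_numpy_array(image_path):
    image = cv2.imread(image_path)  # HxWx3 uint8 array in BGR order, or None if the file can't be read
    if image is None:
        raise IOError('Could not read image: %s' % image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image
# This function turns an image into a 3-D numpy array; either PIL or OpenCV can do this, and here I replaced PIL with cv2.
# Note that OpenCV loads images in BGR order, so they must be converted to RGB.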
import cv2

def load_image_into_numpy_array(image_path):
    image = cv2.imread(image_path)
    # BGR to RGB, method 1: OpenCV's color conversion function
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # BGR to RGB, method 2: reverse the channel axis with a slice
    # image = image[:, :, ::-1]
    return image
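Both loaders end with an array of shape (height, width, 3) in RGB order. The numpy-only sketch below (made-up pixel values, no image file or OpenCV needed) shows the two ideas involved: PIL's getdata() yields pixels row by row, so reshaping to (height, width, 3) rebuilds the image, and [:, :, ::-1] reverses the channel axis, which turns BGR into RGB. The slice returns a view with a negative stride; if a later call insists on contiguous memory, np.ascontiguousarray makes a copy.

```python
import numpy as np

# A 2x2 image as PIL's getdata() would yield it: (R, G, B) tuples, row by row.
pixels = [(255, 0, 0), (0, 255, 0),
          (0, 0, 255), (10, 20, 30)]
width, height = 2, 2
rgb = np.array(pixels).reshape((height, width, 3)).astype(np.uint8)

# cv2.imread would return the same pixels with the channel order reversed (BGR).
bgr = rgb[:, :, ::-1]

# Method 2 from the text: reversing the channel axis again gives RGB back.
restored = bgr[:, :, ::-1]

print(rgb.shape)                      # (2, 2, 3)
print(bgr[0, 0].tolist())             # [0, 0, 255]
print(np.array_equal(restored, rgb))  # True

# The reversed slice is a view with a negative stride; copy it if contiguity matters.
contiguous = np.ascontiguousarray(bgr)
print(bgr.flags["C_CONTIGUOUS"], contiguous.flags["C_CONTIGUOUS"])  # False True
```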
==========================================================================================================================================================================================
# List of the image paths to run detection on.
TEST_IMAGE_PATHS = ['object_detection/test_images/image5.jpg', 'object_detection/test_images/image4.jpg']
================================================================================================================================================================================================
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        for image_path in TEST_IMAGE_PATHS:
            image_np = load_image_into_numpy_array(image_path)
            image_np_expanded = np.expand_dims(image_np, axis=0)
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index,
                use_normalized_coordinates=True, line_thickness=2)
            plt.figure(figsize=[12, 8])
            plt.imshow(image_np)
            plt.show()
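The following lines fetch the tensors named 'image_tensor:0', 'detection_boxes:0', 'detection_scores:0', 'detection_classes:0' and 'num_detections:0' from the imported graph. The suffix :0 denotes the first output of the operation with that name; because the graph was imported with name='', no prefix such as import/ is needed.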
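The exported detection graph works on batches, which is why the code wraps the single image with np.expand_dims(..., axis=0) and later calls np.squeeze on each output. The numpy sketch below uses dummy arrays to show the shape bookkeeping; the 480x640 image size and the 100 detection slots are assumptions chosen for illustration, not values read from the model.

```python
import numpy as np

image_np = np.zeros((480, 640, 3), dtype=np.uint8)     # one H x W x 3 image
image_np_expanded = np.expand_dims(image_np, axis=0)   # batch of one
print(image_np_expanded.shape)                          # (1, 480, 640, 3)

# Dummy outputs with batched shapes like the detector's (100 slots assumed).
boxes = np.zeros((1, 100, 4), dtype=np.float32)   # [ymin, xmin, ymax, xmax] per slot
scores = np.zeros((1, 100), dtype=np.float32)
classes = np.ones((1, 100), dtype=np.float32)     # class ids arrive as floats
num = np.array([100.0], dtype=np.float32)

# np.squeeze drops the batch axis so the visualizer gets per-image arrays.
print(np.squeeze(boxes).shape)    # (100, 4)
print(np.squeeze(scores).shape)   # (100,)
print(np.squeeze(classes).astype(np.int32).dtype)  # int32, usable as dict keys
print(int(np.squeeze(num)))       # 100
```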
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
for image_path in TEST_IMAGE_PATHS:
    image_np = load_image_into_numpy_array(image_path)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    (boxes, scores, classes, num) = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
    # visualize_boxes_and_labels_on_image_array()
    # Draws the detected bounding boxes, with labels and scores, onto the image (in place).
    # image           the image as a 3-D (H, W, 3) array
    # boxes           bounding-box coordinates
    # classes         class ids
    # scores          confidence score of each box
    # category_index  dict mapping each id to its {'id', 'name'} dict
    vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32),
        np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=2)
    print(np.squeeze(num))  # number of detections
    plt.figure(figsize=[12, 8])
    plt.imshow(image_np)
    plt.show()
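With use_normalized_coordinates=True, each box is [ymin, xmin, ymax, xmax] expressed as fractions of the image height and width. To use the detections outside the visualizer (for example to crop objects or print labels), they can be filtered by score and scaled to pixels. The sketch below is a hypothetical helper, boxes_to_pixels, run on made-up detections with an assumed 0.5 threshold; the visualization utility applies its own threshold independently.

```python
import numpy as np


def boxes_to_pixels(boxes, scores, classes, image_shape, min_score=0.5):
    """Keep detections with score >= min_score and convert boxes to pixel ints.

    boxes: (N, 4) array of normalized [ymin, xmin, ymax, xmax].
    Returns a list of (class_id, score, (top, left, bottom, right)).
    """
    height, width = image_shape[:2]
    scale = np.array([height, width, height, width], dtype=np.float32)
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        top, left, bottom, right = (box * scale).round().astype(int)
        results.append((int(cls), float(score),
                        (int(top), int(left), int(bottom), int(right))))
    return results


boxes = np.array([[0.1, 0.2, 0.5, 0.6],
                  [0.0, 0.0, 1.0, 1.0]], dtype=np.float32)
scores = np.array([0.9, 0.3], dtype=np.float32)
classes = np.array([1.0, 3.0], dtype=np.float32)
print(boxes_to_pixels(boxes, scores, classes, image_shape=(480, 640, 3)))
```

The class ids in the result can then be looked up in category_index to get readable names.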