TensorFlow 目标检测实例解析

本文详细介绍了如何使用TensorFlow预训练模型进行物体检测,包括模型下载、环境配置、代码解析及运行过程。通过实例演示了从图片输入到检测结果可视化全过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        从GitHub下载了关于Tensorflow/Model的部分,经过1天的琢磨,终于将程序跑通了,下面是相关的代码,我们一起进行研究学习。

Tensorflow/Model 下代码下载地址:https://github.com/tensorflow/models

#coding:utf-8
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# 需要两个文件,一个是pb格式的模型文件,一个为pbtxt格式的模型标签说明文件。
PATH_TO_CKPT = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
PATH_TO_LABELS = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/mscoco_label_map.pbtxt'
# 指定检测类别的总数
NUM_CLASSES = 90

# 初始化一个计算图
detection_graph = tf.Graph()
with detection_graph.as_default():
	od_graph_def = tf.GraphDef()
	with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
		od_graph_def.ParseFromString(fid.read())
		tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

def load_image_into_numpy_array(image):
	(im_width, im_height) = image.size
	return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
	
TEST_IMAGE_PATHS = ['object_detection/test_images/image5.jpg', 'object_detection/test_images/image4.jpg']

with detection_graph.as_default():
	with tf.Session(graph=detection_graph) as sess:
	    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
	    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
	    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
	    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
	    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
	    
	    for image_path in TEST_IMAGE_PATHS:
	    	image = Image.open(image_path)
	    	image_np = load_image_into_numpy_array(image)
	    	image_np_expanded = np.expand_dims(image_np, axis=0)
	    	(boxes, scores, classes, num) = sess.run(
	    		[detection_boxes, detection_scores, detection_classes, num_detections], 
	    		feed_dict={image_tensor: image_np_expanded})
	    	
	    	vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8)
	    	plt.figure(figsize=[12, 8])
	    	plt.imshow(image_np)
	    	plt.show()

        本文所有的操作都是以research为根目录执行的,下载的源程序可能部分路径设置不正确,自行调整即可,不再赘述。

# 需要两个文件,一个是pb格式的模型文件,一个为pbtxt格式的模型标签说明文件。
PATH_TO_CKPT = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
PATH_TO_LABELS = 'object_detection/ssd_mobilenet_v1_coco_2017_11_17/mscoco_label_map.pbtxt'
# 指定检测类别的总数
NUM_CLASSES = 90

        PATH_TO_CKPT指定了需要加载模型的路径,这里的模型为pb格式的文件。

        PATH_TO_LABELS指定了加载模型对应的标签文件。

        NUM_CLASSES指定了检测类别的总数。如:一共可以检测90种物体,这里就设为90。

# 初始化一个计算图
detection_graph = tf.Graph()
with detection_graph.as_default():
	od_graph_def = tf.GraphDef()
	with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
		od_graph_def.ParseFromString(fid.read())
		tf.import_graph_def(od_graph_def, name='')

       关于计算图的介绍,可以参考:https://blog.youkuaiyun.com/weixin_41874599/article/details/82663676

         detection_graph = tf.Graph()       

         with detection_graph.as_default():       

                od_graph_def = tf.GraphDef()         设置一个计算图,然后将这个计算图设为默认计算图

==========================================================================================

with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

① tf.gfile.GFile(filename, mode)

       获取文本操作句柄,类似于python提供的文本操作open()函数,filename是要打开的文件名,mode是以何种方式去读写,将会返回一个文本操作句柄。tf.gfile.Open()是该接口的同名,可任意使用其中一个!

参考资料:https://blog.youkuaiyun.com/pursuit_zhangyu/article/details/80557958

 

② tf.GraphDef().ParseFromString()

       将序列化消息解析为当前消息。

参考资料:Tensorflow下的帮助文档。参考下图:

这一部分其实就是将Model.ckpt中的计算图进行了结构化恢复。

================================================================================================================================================================================================

abel_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

①label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
        将PATH_TO_LABELS路径中对应的pbtxt文件进行解析。

 

②categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
        label_map_util.convert_label_map_to_categories函数对应在 源码中的解释说明:This function converts label map proto and returns a list of dicts, each of which  has the following keys:

    'id': (required) an integer id uniquely identifying this category.
    'name': (required) string representing category name, e.g., 'cat', 'dog', 'pizza'.

     简单来说就是:将label map proto转换成一个含字典的列表,字典的属性包括了:‘id’ ,‘name’.

     注意:每个类别只能有一个id对应,如果有多个只取第一个。

 

③category_index = label_map_util.create_category_index(categories)
label_map_util.create_category_index(xxx)

xxx为一个列表,列表中的元素都是字典。字典的键为‘id’和‘name’
该函数返回一个字典,字典的键是数字序号,键对应的值为字典,字典就是列表中的字典。具体可以查看源代码。


def load_image_into_numpy_array(image):
	image = cv2.imread(image)
	image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
	return image

# 这个函数是将图片转换成三维的numpy.array信息。无论使用PIL或者opencv都可以实现该功能。这里我使用了CV2替换。    
# 不过在使用Opencv时,注意将BGR格式转换为RGB格式。
def load_image_into_numpy_array(image):
    image = cv2.imread(image)
    # BGR 转 RGB 方法一。    通过opencv的颜色转换函数进行转换
    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    # BGR 转 RGB 方法二。    通过反切片操作进行转换
    # image = image[:,:,::-1]
    return image

==========================================================================================================================================================================================

# 列表,列表中存放需要进行测试的图片名称。
TEST_IMAGE_PATHS = ['object_detection/test_images/image5.jpg', 'object_detection/test_images/image4.jpg']

================================================================================================================================================================================================

with detection_graph.as_default():
	with tf.Session(graph=detection_graph) as sess:
	    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
	    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
	    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
	    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
	    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
	    
	    for image_path in TEST_IMAGE_PATHS:
	    	image_np = load_image_into_numpy_array(image_path)
	    	image_np_expanded = np.expand_dims(image_np, axis=0)
	    	(boxes, scores, classes, num) = sess.run(
	    		[detection_boxes, detection_scores, detection_classes, num_detections], 
	    		feed_dict={image_tensor: image_np_expanded})
	    	vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes),         
            np.squeeze(classes).astype(np.int32),  np.squeeze(scores), category_index, 
            use_normalized_coordinates=True, line_thickness=2)
	    	plt.figure(figsize=[12, 8])
	    	plt.imshow(image_np)
	    	plt.show()

将原计算图中,名称为: 'image_tensor:0'、 'detection_boxes:0'、 'detection_scores:0'、 'detection_classes:0'、 'num_detections:0'  对应的张量取出来。
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        
        for image_path in TEST_IMAGE_PATHS:
            image_np = load_image_into_numpy_array(image_path)
            image_np_expanded = np.expand_dims(image_np, axis=0)
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections], 
                feed_dict={image_tensor: image_np_expanded})

# visualize_boxes_and_labels_on_image_array() 
# 函数的主要作用是对原图像画上回归边框
# image     三维图片,矩阵格式
# boxes     回归框信息
# classes     分类信息
# scores     分类的分数信息
# category_index     字典,字典包括了('id','name')
            vis_util.visualize_boxes_and_labels_on_image_array(image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32),
                                        np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=2)
            print(np.squeeze(num))
            

            plt.figure(figsize=[12, 8])
            plt.imshow(image_np)
            plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值