SSD-Tensorflow 学习一:部署

本文详细介绍如何使用SSD-TensorFlow进行目标检测,包括代码克隆、预训练模型加载、图像预处理、网络定义及预测流程。通过具体示例展示如何处理图像并识别对象,适合初学者快速上手。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.clone代码:

https://github.com/balancap/SSD-Tensorflow

2.提取/home/whut/yyCode/SSD-Tensorflow/checkpoints 下的ssd_300_vgg.ckpt.zip ,并删除文件夹,只保留里面的内容,避免报错:

Could not open ../checkpoints/ssd_300_vgg.ckpt: Failed precondition: ../checkpoints/ssd_300_vgg.ckpt: perhaps your file is in a different file format and you need to use a different restore operator?

3. 可以将.ipynb 转换为.py

# encoding: utf-8
import os
import math
import random
 
import numpy as np
import tensorflow as tf
import cv2
 
slim = tf.contrib.slim
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import sys
sys.path.append('../')
from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing
from notebooks import visualization
# TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!!
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)
# Input placeholder.
net_shape = (300, 300)
data_format = 'NHWC'
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
# Evaluation pre-processing: resize to SSD net shape.
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
    img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
image_4d = tf.expand_dims(image_pre, 0)
 
# Define the SSD model.
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
    predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
 
# Restore SSD model.
ckpt_filename = '/home/whut/yyCode/SSD-Tensorflow/checkpoints/ssd_300_vgg.ckpt'
# ckpt_filename = '../checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)
 
# SSD default anchor boxes.
ssd_anchors = ssd_net.anchors(net_shape)
 
 
# Main image processing routine.
def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
    # Run SSD network.
    rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
                                                              feed_dict={img_input: img})
 
    # Get classes and bboxes from the net outputs.
    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
        rpredictions, rlocalisations, ssd_anchors,
        select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
 
    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
    rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
    rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
    # Resize bboxes to original image shape. Note: useless for Resize.WARP!
    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
    return rclasses, rscores, rbboxes
# Test on some demo image and visualize output.
#测试的文件夹
path = '/home/whut/yyCode/SSD-Tensorflow/demo/'
image_names = sorted(os.listdir(path))
#文件夹中的第几张图,-1代表最后一张
img = mpimg.imread(path + image_names[-1])
rclasses, rscores, rbboxes =  process_image(img)
 
# visualization.bboxes_draw_on_img(img, rclasses, rscores, rbboxes, visualization.colors_plasma)
visualization.plt_bboxes(img, rclasses, rscores, rbboxes)

为了避免出错,建议将相对路径改为绝对路径。

4. 直接运行上述demo.py。

环境说明:

python3+

tensorflow-gpu

opencv

matplotlib

ipykernel:如果向通过jupyter notebook运行

pillow:

如果报错的话,即未安装Pillow库导致不能加载更多格式的图片,如下:

Traceback (most recent call last):
  File "demo.py", line 70, in <module>
    img = mpimg.imread(path + image_names[-1])
  File "/home/whut/anaconda2/envs/env_7/lib/python3.6/site-packages/matplotlib/image.py", line 1282, in imread
    'more images' % list(six.iterkeys(handlers)))
ValueError: Only know how to handle extensions: ['png']; with Pillow installed matplotlib can handle more images

当然可能存在一定的安装顺序问题。

建议可以直接用conda进行安装,需要啥安装啥。

 参考:

https://www.cnblogs.com/guohaoblog/p/9797064.html

https://blog.youkuaiyun.com/seven_year_promise/article/details/79306717

https://blog.youkuaiyun.com/w5688414/article/details/78286726

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值