环境准备
运行这个预训练的模型需要准备一些环境
首先需要下载谷歌的models-master.zip
地址在https://github.com/Master-Chen/models
下载完成后我们需要的是research/objection_detection这个项目
在运行这个项目之前还需要下载谷歌的protoc3.4.0
下载结束后只需要将bin目录里的protoc.exe文件放在有环境变量的一个目录下即可
之后在research路径下打开命令行 运行 protoc objection_detection/protocs/*.proto --python_out=.
这里运行后会在object_detection\protos路径下生成许多py文件,相当于把原来的proto文件编译成了py文件
至此,环境准备基本完成。注意的是,这里使用的tensorflow1.13.1-cpu
运行模型
准备工作完成后,在objection_detection路径下启动jupyter notebook,找到
进入这个笔记本
可以看到,这个笔记本将引导使用者运行这个预训练的目标检测模型
- 导入相关模块
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops
# tf版本需要大于1.9
if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
- 在jupyter中显示图片
# 在jupyter里面显示图片
%matplotlib inline
- 导入模块
from utils import label_map_util
from utils import visualization_utils as vis_util
- 指定模型的相关配置,譬如模型名称,下载地址,对应得pb文件存放路径,数据集label映射文件路径
这里使用的是SSD模型,在coco数据集上训练的,其他模型文件可以在github下载
# 模型名称
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
# 下载地址
DOWNLOAD_BASE