TensorFlow 训练自己的目标检测器
本文主要描述如何使用 Google 开源的目标检测 API 来训练目标检测器,内容包括:安装 TensorFlow/Object Detection API 和使用 TensorFlow/Object Detection API 训练自己的目标检测器。
一、安装 TensorFlow Object Detection API
Google 开源的目标检测项目 object_detection 位于与 tensorflow 独立的项目 models(独立指的是:在安装 tensorflow 的时候并没有安装 models 部分)内:models/research/object_detection。models 部分的 GitHub 主页为:
https://github.com/tensorflow/models
要使用 models 部分内的目标检测功能 object_detection,需要用户手动安装 object_detection。下面为详细的安装步骤:
1. 安装依赖项 matplotlib,pillow,lxml 等
使用 pip/pip3 直接安装:
$ sudo pip/pip3 install matplotlib pillow lxml
其中如果安装 lxml 不成功,可使用
$ sudo apt-get install python-lxml python3-lxml
安装。
2. 安装编译工具
$ sudo apt install protobuf-compiler
$ sudo apt-get install python-tk
$ sudo apt-get install python3-tk
3. 克隆 TensorFlow models 项目
使用 git 克隆 models 部分到本地,在终端输入指令:
$ git clone https://github.com/tensorflow/models.git
克隆完成后,会在终端当前目录出现 models 的文件夹。要使用 git(分布式版本控制系统),首先得安装 git:$ sudo apt-get install git
。
4. 使用 protoc 编译
在 models/research 目录下的终端执行:
$ protoc object_detection/protos/*.proto --python_out=.
将 object_detection/protos/ 文件下的以 .proto 为后缀的文件编译为 .py 文件输出。
5. 配置环境变量
在 .bashrc 文件中加入环境变量。首先打开 .bashrc 文件:
$ sudo gedit ~/.bashrc
然后在文件末尾加入新行:
export PYTHONPATH=$PYTHONPATH:/.../models/research:/.../modes/research/slim
其中省略号所在的两个目录需要填写为 models/research 文件夹、models/research/slim 文件夹的完整目录。保存之后执行如下指令:
$ source ~/.bashrc
让改动立即生效。
6. 测试是否安装成功
在 models/research 文件下执行:
$ python/python3 object_detection/builders/model_builder_test.py
如果返回 OK,表示安装成功。
二、训练 TensorFlow 目标检测器
成功安装好 TensorFlow Object Detection API 之后,就可以按照 models/research/object_detection 文件夹下的演示文件 object_detection_tutorial.ipynb 来查看 Google 自带的目标检测的检测效果。其中,Google 自己训练好后的目标检测器都放在:
可以自己下载这些模型,一一查看检测效果。以下,假设你把某些预训练模型下载好了,放在models/ research/ object_detection 的某个文件夹下,比如自定义文件夹 pretrained_models。
要训练自己的模型,除了使用 Google 自带的预训练模型之外,最关键的是需要准备自己的训练数据。
以下,详细列出训练过程(后续部分文章将详细介绍一些目标检测算法):
1. 准备标注工具和文件格式转化工具
图像标注可以使用标注工具 labelImg,直接使用
$ sudo pip install labelImg
安装(当前好像只支持Python2.7)。另外,在此之前,需要安装它的依赖项 pyqt4:
$ sudo apt-get install pyqt4-dev-tools
(另一依赖项 lxml 前面已安装)。要使用 labelImg,只需要在终端输入 labelImg 即可。
为了方便后续数据格式转化,还需要准备两个文件格式转化工具:xml_to_csv.py 和 generate_tfrecord.py,它们的代码分别列举如下(它们可以从资料 [1] 中 GitHub 项目源代码链接中下载。其中为了方便一般化使用,我已经修改 generate_tfrecord.py 的部分内容使得可以自定义图像路径和输入 .csv 文件、输出 .record 文件路径,以及 6 中的 xxx_label_map.pbtxt 文件路径):
(1) xml_to_csv.py 文件源码:
import os