感谢这位作者,以下记录是来自于https://blog.youkuaiyun.com/wangjian1204/article/details/79124018
我看到比较好,就转记录到自己的博客了,如果有侵权,立马删掉
图像分类和目标检测是计算机视觉两大模块。相比于图像分类,目标检测任务更复杂更困难。目标检测不但要检测到具体的目标,还要定位目标的具体位置。不过Tensorflow models上大神们的无私奉献已经使得目标检测模型平民化,只需要按照特定的格式准备好训练数据,就可以轻松训练出自己想要的目标检测模型。本文通过一个例子介绍如何通过Tensorflow models快速构建目标检测模型。
准备工作:
-
从github上clone最新的代码到本地:https://github.com/tensorflow/models
-
我们会用到models/research/object_detection/这个project的代码来训练和测试目标检测模型;models里面的项目很多都用protobuf来配置,所以需要在research目录下进行protobuf编译。
-
具体命令参考:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
-
下载训练图片和object标记数据并解压:
wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz tar -xvf images.tar.gz tar -xvf annotations.tar.gz
这个数据集包含37个种类的猫和狗。
-
在训练的机器上安装tensorflow 1.4版本(具体安装方法参考 Tensorflow 官网),在tf1.4新增了目标检测非极大值抑制,结果展示能模块,模型的训练在tf1.2及以上版本都可以正常执行;
目标检测模型训练:
目标检测模型的训练很简单,只需要做一些配置工作告诉代码去哪里读取数据,结果保存到哪里就可以了。
1. 首先对数据做一遍处理,把训练数据处理成tfrecords的格式:
python object_detection/dataset_tools/create_pet_tf_record.py \
--label_map_path=object_detection/data/pet_label_map.pbtxt \ #object_detection/data目录下已经有这个文件了,是训练数据类别的描述文件
--data_dir=`you_path/data` \ # 存放训练数据images和annotations的文件
--output_dir=`you_path/data` # tfrecords保存路径(文件夹)
- 这行代码是运行model/research/object_detection下的数据处理代码把你下载的images和annotations处理后的数据保存到tfrecords文件。运行结束后,你会发现you_path/data目录下多了两个文件:pet_train_with_masks.record、pet_val_with_masks.record。这两个文件是用来进行模型训练和验证的。
2. 下载预训练好的模型:
COCO-pretrained Faster R-CNN with Resnet-101 model:
wget http://storage.googleapis.com/download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_2017_11_08.tar.gz
- 新的模型不会从0开始训练,而是在这个模型的基础上进行调整,也就是一种迁移学习的方法。
3. 配置文件路径:
在object_detection/samples/configs目录下编辑faster_rcnn_resnet101_pets.config,只要修改下面几行,改成我们对应的文件路径即可:
(1)
fine_tune_checkpoint: "/data/pets/faster_rcnn_resnet101_coco_2017_11_08/model.ckpt" # 上面第二步下载的预训练模型
(2)
train_input_reader: {
tf_record_input_reader {
input_path: "/data/pets/pet_train_with_masks.record" #预处理数据生成的tfrecords文件
}
label_map_path: tensorflow/models/research/object_detection/data/pet_label_map.pbtxt" #这个文件在预处理数据时也用到过,用来说明训练样本的类别信息
}
(3)
eval_input_reader: {
tf_record_input_reader {
input_path: "/data/pets/pet_val_with_masks.record" #预处理数据生成的tfrecords文件
}
label_map_path: "/home/recsys/hzwangjian1/tensorflow/models/research/object_detection/data/pet_label_map.pbtxt" #这个文件和(2)中是同一个文件,用来说明训练样本的类别信息
shuffle: false
num_readers: 1
}
4. 开始训练模型:
python3 models/research/object_detection/train.py --logtostderr --train_dir=data/ --pipeline_config_path=object_detection/samples/configs/faster_rcnn_resnet101_pets.config
–train_dir:模型的checkpoint和summary都会保存在这个路径下
–pipeline_config_path:上一步的配置文件
5. 模型导出和预测新的图片:
首先把图导出到一个.pb文件,然后可以直接加载.pb文件恢复整个模型。
python3 models/research/object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path object_detection/samples/configs/faster_rcnn_resnet101_pets.config \
--trained_checkpoint_prefix data/model.ckpt-15000 \
--output_directory object_detection_graph
–pipeline_config_path:上一步训练时使用的配置文件,也可以与训练时使用的配置文件不同。
–trained_checkpoint_prefix:模型训练的checkpoint
–output_directory:图导出路径
执行代码后会生成.pb文件,修改运行models/research/object_detection/object_detection_tutorial.ipynb这个文件就可以来进行预测测试了。
参考资料:
1、https://www.oreilly.com/ideas/object-detection-with-tensorflow
2、https://github.com/tensorflow/models/tree/master/research/object_detection
3、https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md