一. Installing the Object Detection API
Reference: Installation
Run the following commands in order:
git clone https://github.com/tensorflow/models.git
conda create -n tf1.12 python=3.6
conda activate tf1.12
# -------------------------------------------------
# For CPU
pip install tensorflow==1.12
# For GPU
pip install tensorflow-gpu==1.12
# -------------------------------------------------
sudo apt-get install protobuf-compiler python-pil python-lxml python-tk
pip install --user Cython
pip install --user contextlib2
pip install --user pillow
pip install --user lxml
pip install --user jupyter
pip install --user matplotlib
# -------------------------------------------------
# COCO API installation
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
make
cp -r pycocotools <path_to_tensorflow>/models/research/
# -------------------------------------------------
# From tensorflow/models/research/
protoc object_detection/protos/*.proto --python_out=.
# -------------------------------------------------
# Add Libraries to PYTHONPATH
vim ~/.bashrc
export PYTHONPATH=$PYTHONPATH:/data2/zzw/Tensorflow-test/models/research:/data2/zzw/Tensorflow-test/models/research/slim
source ~/.bashrc
# -------------------------------------------------
# From tensorflow/models/research/
python setup.py build
python setup.py install
# -------------------------------------------------
# Testing the Installation
python object_detection/builders/model_builder_test.py
................
----------------------------------------------------------------------
Ran 16 tests in 0.112s
OK
二. Using the Object Detection API
1. Exporting the pb model file
cd models/research/
export CONFIG_FILE=ssd_mobilenet_v1_coco_2018_01_28/pipeline.config
export CHECKPOINT_PATH=ssd_mobilenet_v1_coco_2018_01_28/model.ckpt
export OUTPUT_DIR=output
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path=$CONFIG_FILE \
--trained_checkpoint_prefix=$CHECKPOINT_PATH \
--output_directory=$OUTPUT_DIR \
--add_postprocessing_op=false
Here CONFIG_FILE is the config file used when training the MSSD, CHECKPOINT_PATH is the intermediate ckpt file produced by training, and OUTPUT_DIR is the directory where the exported pb file will be written. add_postprocessing_op must be set to false here (MNN does not support the Postprocessing op; after MNN finishes one forward pass, we perform the postprocessing on the CPU side, which is essentially decoding and NMS).
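The NMS half of that CPU-side postprocessing can be illustrated with a minimal sketch. This is not MNN's actual implementation, just a greedy non-maximum suppression over already-decoded boxes in NumPy; the 0.45 IoU threshold is an assumed typical value:

```python
import numpy as np

def nms(boxes, scores, iou_threshold=0.45):
    """Greedy non-maximum suppression.

    boxes:  (N, 4) array of [ymin, xmin, ymax, xmax]
    scores: (N,) confidence scores
    Returns the indices of the boxes to keep, best first.
    """
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        if order.size == 1:
            break
        # IoU between the top box and all remaining boxes
        rest = boxes[order[1:]]
        ymin = np.maximum(boxes[i, 0], rest[:, 0])
        xmin = np.maximum(boxes[i, 1], rest[:, 1])
        ymax = np.minimum(boxes[i, 2], rest[:, 2])
        xmax = np.minimum(boxes[i, 3], rest[:, 3])
        inter = np.clip(ymax - ymin, 0, None) * np.clip(xmax - xmin, 0, None)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
        iou = inter / (area_i + area_r - inter)
        # drop boxes that overlap the kept box too much
        order = order[1:][iou <= iou_threshold]
    return keep
```

On the CPU this runs per class over the decoded boxes, after filtering out low-score detections.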
We can modify the num_classes, height, and width parameters in pipeline.config:
model {
  ssd {
    num_classes: 90
    image_resizer {
      fixed_shape_resizer {
        height: 300
        width: 300
      }
    }
  }
}
Finally, tflite_graph.pb and tflite_graph.pbtxt are generated in the output folder.
2. Network trimming
First build tensorflow==1.12.0 from source, then:
cd tensorflow
export OUTPUT_DIR=output
bazel run --config=opt tensorflow/lite/toco:toco -- \
--input_file=$OUTPUT_DIR/tflite_graph.pb \
--output_file=$OUTPUT_DIR/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='concat','concat_1' \
--inference_type=FLOAT \
--change_concat_input_ranges=false
TensorFlow's Object Detection API exports the nodes before PostProcessing, i.e. concat and concat_1; for the postprocessing, the results can be fetched from the device and processed on the CPU side.
Finally, detect.tflite is generated in the output folder and can be ported to the Android client.
Reference: Porting ssd_mobilenet to an Android client with TensorFlow Lite
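For the decoding half of the postprocessing: concat holds the raw SSD box encodings (ty, tx, th, tw) relative to the anchors, and concat_1 the class scores. A minimal NumPy sketch of the standard SSD decoding follows, assuming the TF Object Detection API's usual box coder scale factors 10, 10, 5, 5 (check the box_coder section of your pipeline.config):

```python
import numpy as np

def decode_boxes(encodings, anchors, scale=(10.0, 10.0, 5.0, 5.0)):
    """Decode SSD box encodings (ty, tx, th, tw) against anchors.

    encodings: (N, 4) raw network output (the 'concat' tensor)
    anchors:   (N, 4) anchors as [ycenter, xcenter, height, width]
    Returns (N, 4) boxes as [ymin, xmin, ymax, xmax] in normalized coords.
    """
    ty, tx, th, tw = encodings.T
    ya, xa, ha, wa = anchors.T
    # shift the anchor center, scale the anchor size
    ycenter = ty / scale[0] * ha + ya
    xcenter = tx / scale[1] * wa + xa
    h = np.exp(th / scale[2]) * ha
    w = np.exp(tw / scale[3]) * wa
    # convert from center/size to corner coordinates
    return np.stack([ycenter - h / 2, xcenter - w / 2,
                     ycenter + h / 2, xcenter + w / 2], axis=1)
```

The anchors themselves are fixed once the input size is fixed, so they can be precomputed offline and baked into the client; decoding followed by per-class NMS reproduces what the TFLite postprocessing op would have done.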
Exporting a trained model for inference
Running on mobile with TensorFlow Lite