使用Object Detection API训练
准备工作
Running Locally中提到准备工作大致有三个:
- 安装Tensorflow Object Detection API
- 数据集
- Object Detection pipeline设置文件
安装Tensorflow Object Detection API
没啥说的,看官网教程Installation
数据集
可以按照Preparing Inputs来准备TFRecord格式的数据集。
当然也可以使用models/research/object_detection/dataset_tools/下的脚本将常见的数据集创建成TFRecord格式的。其中常用的有create_pascal_tf_record.py,就是将安装PASCAL VOC组织的数据转换为TFRecord格式。在使用时,代码其中的’aeroplane_’ + 多余,删去即可。
Object Detection pipeline设置文件
将object_detection/samples/configs/对应的config文件拷贝一份,然后根据实际情况修改。
- num_classes:修改为自己的classes num
- 将所有PATH_TO_BE_CONFIGURED的地方修改为自己之前设置的路径(共5处)
- batch_size根据情况修改,初始设置可能会导致内存不够用。
训练
legacy
train.py和eval.py被移到legacy文件下了。
python object_detection/legacy/train.py \
--logtostderr \
--pipeline_config_path=${PIPELINE_CONFIG_PATH} \
--train_dir=${TRAIN_DIR}
python object_detection/legacy/eval.py \
--logtostderr \
--checkpoint_dir=${TRAIN_DIR} \
--eval_dir=${EVAL_DIR} \
--pipeline_config_path=${PIPELINE_CONFIG_PATH}
如果报关于unicode的错,将object_detection\utils\object_detection_evaluation.py下的category_name = unicode(category_name, ‘utf-8’)修改为category_name = str(category_name)
recommend
model_main.py将train和eval结合在一块,官方推荐使用。
python object_detection/legacy/eval.py \
--logtostderr \
--checkpoint_dir=${TRAIN_DIR} \
--eval_dir=${EVAL_DIR} \
--pipeline_config_path=${PIPELINE_CONFIG_PATH}
- 添加 tf.logging.set_verbosity(tf.logging.INFO) 到model_main.py 的 import 区域之后,会每隔一百个step输出loss,总比没有好,至少它让你知道它在跑。
- 如果是python3训练,添加list() 到 model_lib.py的大概390行 category_index.values()变成: list(category_index.values()),否则会有 can’t pickle dict_values ERROR出现
- 还有一个问题是,用model_main.py 训练时,因为它把老版本的train.py和eval.py集合到了一起,所以制定eval num时指定不好会有warning出现,就像:
导出模型
export INPUT_TYPE=image_tensor
python object_detection/export_inference_graph.py \
--input_type=${INPUT_TYPE} \
--pipeline_config_path=${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix=${TRAIN_DIR}/`head -n 1 ${TRAIN_DIR}/checkpoint | grep -o -E '\".+\"' | sed s/\"//g` \
--output_directory=${EXPORT_DIR}
参考
TensorFlow object detection API应用
第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)