基于Person-Detection-and-Tracking项目的牛津宠物数据集分布式训练指南
项目概述
Person-Detection-and-Tracking项目是一个基于TensorFlow Object Detection API开发的目标检测与跟踪系统。本教程将指导您如何使用该项目在牛津-IIIT宠物数据集上训练一个能够检测不同品种猫狗的目标检测模型。
环境准备
Google Cloud平台配置
- 创建GCP项目:首先需要在Google Cloud平台上创建新项目
- 安装Cloud SDK:在本地工作站安装Google Cloud SDK工具包
- 启用ML Engine API:新项目默认不启用机器学习API,需手动开启
- 设置存储桶:创建Google Cloud Storage存储桶用于存放训练数据
完成上述步骤后,建议设置环境变量方便后续操作:
export YOUR_GCS_BUCKET=您的存储桶名称
数据集准备
获取牛津-IIIT宠物数据集
-
下载数据集压缩包:
- 图像数据集:images.tar.gz
- 标注数据:annotations.tar.gz
-
解压数据集:
tar -xvf images.tar.gz
tar -xvf annotations.tar.gz
转换为TFRecord格式
使用项目提供的脚本将原始数据转换为TensorFlow所需的TFRecord格式:
python object_detection/dataset_tools/create_pet_tf_record.py \
--label_map_path=object_detection/data/pet_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`
转换完成后将生成:
- pet_train.record(训练集)
- pet_val.record(验证集)
上传至云存储
将转换后的数据上传至GCS存储桶:
gsutil cp pet_train.record gs://${YOUR_GCS_BUCKET}/data/pet_train.record
gsutil cp pet_val.record gs://${YOUR_GCS_BUCKET}/data/pet_val.record
gsutil cp object_detection/data/pet_label_map.pbtxt gs://${YOUR_GCS_BUCKET}/data/pet_label_map.pbtxt
模型准备
下载预训练模型
为加速训练,我们使用在COCO数据集上预训练的Faster R-CNN模型进行迁移学习:
- 下载预训练模型
- 解压并上传模型检查点文件至GCS
gsutil cp faster_rcnn_resnet101_coco_11_06_2017/model.ckpt.* gs://${YOUR_GCS_BUCKET}/data/
配置训练管道
项目使用配置文件定义模型和训练参数:
- 修改模板配置文件
faster_rcnn_resnet101_pets.config
- 替换文件中的
PATH_TO_BE_CONFIGURED
为实际GCS路径 - 上传配置文件至GCS
sed -i "s|PATH_TO_BE_CONFIGURED|"gs://${YOUR_GCS_BUCKET}"/data|g" \
object_detection/samples/configs/faster_rcnn_resnet101_pets.config
启动分布式训练
打包项目代码
python setup.py sdist
(cd slim && python setup.py sdist)
提交训练任务
配置10个训练节点(1主节点+9工作节点)和3个参数服务器:
gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \
--runtime-version 1.2 \
--job-dir=gs://${YOUR_GCS_BUCKET}/train \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
--module-name object_detection.train \
--region us-central1 \
--config object_detection/samples/cloud/cloud.yml \
-- \
--train_dir=gs://${YOUR_GCS_BUCKET}/train \
--pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/faster_rcnn_resnet101_pets.config
启动并行评估任务
gcloud ml-engine jobs submit training `whoami`_object_detection_eval_`date +%s` \
--runtime-version 1.2 \
--job-dir=gs://${YOUR_GCS_BUCKET}/train \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
--module-name object_detection.eval \
--region us-central1 \
--scale-tier BASIC_GPU \
-- \
--checkpoint_dir=gs://${YOUR_GCS_BUCKET}/train \
--eval_dir=gs://${YOUR_GCS_BUCKET}/eval \
--pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/faster_rcnn_resnet101_pets.config
训练监控
使用TensorBoard监控训练进度:
tensorboard --logdir=gs://${YOUR_GCS_BUCKET}
访问localhost:6006
可查看:
- 训练损失曲线
- 评估指标
- 模型预测示例图像
模型导出
训练完成后,导出TensorFlow计算图:
- 从GCS下载模型检查点文件
- 运行导出脚本
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path object_detection/samples/configs/faster_rcnn_resnet101_pets.config \
--trained_checkpoint_prefix model.ckpt-${CHECKPOINT_NUMBER} \
--output_directory exported_graphs
导出的模型可用于后续推理任务。
进阶应用
完成基础目标检测训练后,您还可以:
- 尝试实例分割配置(添加mask预测)
- 使用自定义数据集训练
- 调整模型超参数进行优化
- 将模型部署到生产环境
本教程提供了使用Person-Detection-and-Tracking项目进行分布式训练的基础流程,开发者可根据实际需求进行调整和扩展。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考