Tensorflow object detection API训练自己的数据

本文详细介绍了如何使用Tensorflow Object Detection API进行自定义数据的训练,包括安装、数据准备、模型配置、训练及结果测试。首先,需要安装TF Object Detection API及相关依赖,并确保版本匹配。接着,利用labelImg工具进行数据标注并转换为tfrecord格式。然后,选择预训练模型如ssd_mobilenet_v1_coco,并调整config文件以适应自己的数据。训练过程中,使用tensorboard监控损失曲线。最后,导出模型并进行效果测试。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一. 安装

    Tensorflow object detection api是tensorflow官方出品的检测工具包,集成了像ssd、faster rcnn等检测算法,mobilenet、inception、resnet等backbone和fpn、ppn等方法,各模块之间能够通过组合的方式来work。

    Github下载地址:https://github.com/tensorflow/models

    解压models-master,主要的内容都在/research/目录下,里面有很多代码包括ocr、nlp、speech等,本节我们只需要关注object_detection,也就是目标检测部分。

    代码安装步骤如下:

    1)确保电脑上已安装了tensorflow环境+keras,建议tf版本>=1.6,注意tf版本要与CUDA、cudnn环境匹配,否则可能会有意想不到的错误。ps:tensorflow版本升级过快,前向兼容性不好(接口不一致),确实得吐槽。

    2)编译protoc (建议版本3.4+)
         cd research
         protoc --python_out=. object_detection/protos/*.proto

    3)安装research & slim
         python setup.py install
         cd slim
         python setup.py install

    4)添加系统路径(research目录)
         export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
         也可以添加到系统变量:
         vim ~/.bashrc
         export PYTHONPATH=$PYTHONPATH:/path/to/research:/path/to/research/slim
         source ~/.bashrc

    5)测试是否安装成功(research目录)
         python object_detection/builders/model_builder_test.py
         安装成功会显示ok。

二. 数据准备

    数据标注工具建议使用labelImg,采用xml进行数据保存,相关教程比较多,这里不再赘述。

    这里需要说明的是:tensorflow需要tfrecord格式数据,需要在完成的标注数据基础上进行数据转换,这里将标注数据分为两个文件夹,train和test,文件夹下包含图片文件和xml。

    需要通过脚本进行数据转换:首先将annotation转换成csv,然后将csv转换成tfrecord。

       python xml_to_csv.py /path/to/train ./data/train_labels.csv
       python xml_to_csv.py /path/to/test ./data/test_labels.csv
       
       python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train.record
       python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record

# xml_to_csv.py
# -*- coding: utf-8 -*-

import os, sys
import glob
import pandas as pd
import xml.etree.ElementTree as ET

def xml_to_csv(_path, _out_file):
    xml_list = []
    for xml_file in glob.glob(_path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
              
TensorFlow Object Detection API 是一个开源项目,它提供了一系列基于 TensorFlow 的工具和库,用于实现目标检测任务。对于 macOS 系统,我们可以通过以下步骤来使用 TensorFlow Object Detection API: 1. 安装 TensorFlow:在 macOS 上安装 TensorFlow 是使用 TensorFlow Object Detection API 的前提。你可以通过 pip 命令进行安装,例如在终端中执行 `pip install tensorflow`。 2. 下载 TensorFlow Object Detection API:打开终端并导航到适合你的工作目录中,然后使用 git 命令来克隆 TensorFlow Object Detection API 的 GitHub 仓库,例如执行 `git clone https://github.com/tensorflow/models.git`。 3. 安装依赖项:进入克隆的模型目录中,找到 research 文件夹并进入。然后运行 `pip install -r object_detection/requirements.txt` 命令来安装所需的依赖项。 4. 下载预训练模型:在 TensorFlow Object Detection API 中,我们可以使用预训练的模型来进行目标检测。你可以从 TensorFlow Model Zoo 中下载适合你任务的模型,并将其解压到你的工作目录中。 5. 运行实例代码:在 research/object_detection 目录中,你可以找到一些示例代码,用于训练、评估和使用目标检测模型。可以通过阅读这些示例代码并根据自己的需求进行修改。例如,你可以使用 `python object_detection/builders/model_builder_tf2_test.py` 命令来运行一个模型的测试。 以上是在 macOS 上使用 TensorFlow Object Detection API 的基本步骤,你可以根据你的具体需求进行更多的深入研究和调整。希望这些信息能帮助到你!
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值