使用TensorFlow-Object-Detection-API 训练ssd_mobilenet

1. 数据集准备

        数据的准备需要自己去网上找或者去某宝进行购买,自己找的话可以去UCI:Home - UCI Machine Learning Repository,kaggle:Kaggle: Your Home for Data Science以及一些其他数据集网站进行查找。

        当数据集寻找完毕后使用labellmg进行数据的标注,,软件链接https://pan.baidu.com/s/1GNGdTYqlWoqqlmmvfLG7Gw?pwd=593v,进入软件的data目录在predefined_classes.txt下写入你需要打标签的类。

        打开软件,界面如下

        点击open dir打开图像文件夹

        点击change save dir 选择标签文件保存目录

        输入W进行画框

        A 和 D是上一张和下一张配合W进行愉快的打标之旅。

        当打完之后会在标签保存文件夹下看到很多xml的标签文件,这是VOC格式的YOLO的标签文件为txt文件。

        然后是数据增强,如果担心自己的数据集数量不够,使用数据增强进行扩充数据集。打开这个大佬的工程。可以选择下载ZIP包或者使用 git 进行克隆。bubbliiiing/object-detection-augmentation: 这里面存放了一些目标检测算法的数据增强方法。如mosaic、mixup。

        接着将将图片和标注文件分别放入VOCdevkit_Origin\VOC2007\JPEGImages  和VOCdevkit_Origin\VOC2007\Annotations文件中。

        接着选择三种增强方式进行增强。

        在文件中唯一需要改的就是out_num代表你需要增强多少张图片,图中代表最后总共会生成500张增强的图片以及对应的标注文件。点击运行进行增强,代码用使用到OpenCV、numpy、xml、PIL等模块请自行下载对应的包。

        增强完后会在VOCdevkit文件下生成对应的文件。

        然后将增强完的数据集放入最后需要的VOC文件夹目录如下

VOCdevkit/
└── VOC2007/
    ├── Annotations/              # 存放每个图像对应的 XML 标注文件
    ├── ImageSets/                # 包含训练集、验证集、测试集的划分列表
    │   ├── Main/                # 分类和检测任务的划分文件(如 train.txt, val.txt, trainval.txt, test.txt)
    ├── JPEGImages/               # 原始图像文件(JPEG格式)

        以我为例,我最后得到了接近12000张数据

        接着将数据集和一些这些文件放在一起进行数据划分。

        update_xml_attributes.py内容如下

# -*- coding: utf-8 -*-

import os
import xmltodict
from xml.dom.minidom import parseString
import logging
from tqdm import tqdm

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# 定义参数
annotation_dir = 'VOCdevkit/VOC2007/Annotations'  # 标注文件目录
new_path_prefix = 'VOCdevkit/VOC2007/JPEGImages/'  # 新的路径前缀
new_folder = 'images'  # 新的文件夹名称

def update_xml_attributes():
    """更新XML文件中的path、folder和filename属性"""
    # 获取所有XML文件
    xml_files = [f for f in os.listdir(annotation_dir) if f.lower().endswith('.xml')]
    logger.info(f"找到 {len(xml_files)} 个XML文件需要处理")
    
    # 处理每个XML文件
    for xml_file in tqdm(xml_files, desc="处理XML文件"):
        xml_path = os.path.join(annotation_dir, xml_file)
        
        try:
            # 读取XML文件
            with open(xml_path, 'r', encoding='utf-8') as f:
                xml_data = xmltodict.parse(f.read())
            
            # 获取对应的图片文件名(将.xml替换为.jpg)
            image_filename = os.path.splitext(xml_file)[0] + '.jpg'
            
            # 更新path、folder和filename属性
            xml_data['annotation']['folder'] = new_folder
            xml_data['annotation']['filename'] = image_filename
            xml_data['annotation']['path'] = os.path.join(new_path_prefix, image_filename)
            
            # 将更新后的数据写回文件
            xml_str = xmltodict.unparse(xml_data, pretty=True)
            dom = parseString(xml_str)
            pretty_xml = dom.toprettyxml(indent='  ')
            
            with open(xml_path, 'w', encoding='utf-8') as f:
                f.write(pretty_xml)
                
        except Exception as e:
            logger.error(f"处理文件 {xml_file} 时出错: {str(e)}")
    
    logger.info(f"成功更新 {len(xml_files)} 个XML文件的属性")

if __name__ == "__main__":
    update_xml_attributes()

        utils.py内容如下

import numpy as np
from PIL import Image


#---------------------------------------------------------#
#   将图像转换成RGB图像,防止灰度图在预测时报错。
#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image 
    else:
        image = image.convert('RGB')
        return image 

#---------------------------------------------------#
#   对输入图像进行resize
#---------------------------------------------------#
def resize_image(image, size, letterbox_image):
    iw, ih  = image.size
    w, h    = size
    if letterbox_image:
        scale   = min(w/iw, h/ih)
        nw      = int(iw*scale)
        nh      = int(ih*scale)

        image   = image.resize((nw,nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128,128,128))
        new_image.paste(image, ((w-nw)//2, (h-nh)//2))
    else:
        new_image = image.resize((w, h), Image.BICUBIC)
    return new_image

#---------------------------------------------------#
#   获得类
#---------------------------------------------------#
def get_classes(classes_path):
    with open(classes_path, encoding='utf-8') as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names, len(class_names)

def show_config(**kwargs):
    print('Configurations:')
    print('-' * 70)
    print('|%25s | %40s|' % ('keys', 'values'))
    print('-' * 70)
    for key, value in kwargs.items():
        print('|%25s | %40s|' % (str(key), str(value)))
    print('-' * 70)
    
#-------------------------------------------------------------------------------------------------------------------------------#
#   From https://github.com/ckyrkou/Keras_FLOP_Estimator 
#   Fix lots of bugs
#-------------------------------------------------------------------------------------------------------------------------------#
def net_flops(model, table=False, print_result=True):
    if (table == True):
        print("\n")
        print('%25s | %16s | %16s | %16s | %16s | %6s | %6s' % (
            'Layer Name', 'Input Shape', 'Output Shape', 'Kernel Size', 'Filters', 'Strides', 'FLOPS'))
        print('=' * 120)
        
    #---------------------------------------------------#
    #   总的FLOPs
    #---------------------------------------------------#
    t_flops = 0
    factor  = 1e9

    for l in model.layers:
        try:
            #--------------------------------------#
            #   所需参数的初始化定义
            #--------------------------------------#
            o_shape, i_shape, strides, ks, filters = ('', '', ''), ('', '', ''), (1, 1), (0, 0), 0
            flops   = 0
            #--------------------------------------#
            #   获得层的名字
            #--------------------------------------#
            name    = l.name
            
            if ('InputLayer' in str(l)):
                i_shape = l.get_input_shape_at(0)[1:4]
                o_shape = l.get_output_shape_at(0)[1:4]
                
            #--------------------------------------#
            #   Reshape层
            #--------------------------------------#
            elif ('Reshape' in str(l)):
                i_shape = l.get_input_shape_at(0)[1:4]
                o_shape = l.get_output_shape_at(0)[1:4]

            #--------------------------------------#
            #   填充层
            #--------------------------------------#
            elif ('Padding' in str(l)):
                i_shape = l.get_input_shape_at(0)[1:4]
                o_shape = l.get_output_shape_at(0)[1:4]

            #--------------------------------------#
            #   平铺层
            #--------------------------------------#
            elif ('Flatten' in str(l)):
                i_shape = l.get_input_shape_at(0)[1:4]
                o_shape = l.get_output_shape_at(0)[1:4]
                
            #--------------------------------------#
            #   激活函数层
            #--------------------------------------#
            elif 'Activation' in str(l):
                i_shape = l.get_input_shape_at(0)[1:4]
                o_shape = l.get_output_shape_at(0)[1:4]
                
            #--------------------------------------#
            #   LeakyReLU
            #--------------------------------------#
            elif 'LeakyReLU' in str(l):
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                    
                    flops   += i_shape[0] * i_shape[1] * i_shape[2]
                    
            #--------------------------------------#
            #   池化层
            #--------------------------------------#
            elif 'MaxPooling' in str(l):
                i_shape = l.get_input_shape_at(0)[1:4]
                o_shape = l.get_output_shape_at(0)[1:4]
                    
            #--------------------------------------#
            #   池化层
            #--------------------------------------#
            elif ('AveragePooling' in str(l) and 'Global' not in str(l)):
                strides = l.strides
                ks      = l.pool_size
                
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                    
                    flops   += o_shape[0] * o_shape[1] * o_shape[2]

            #--------------------------------------#
            #   全局池化层
            #--------------------------------------#
            elif ('AveragePooling' in str(l) and 'Global' in str(l)):
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                    
                    flops += (i_shape[0] * i_shape[1] + 1) * i_shape[2]
                
            #--------------------------------------#
            #   标准化层
            #--------------------------------------#
            elif ('BatchNormalization' in str(l)):
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]

                    temp_flops = 1
                    for i in range(len(i_shape)):
                        temp_flops *= i_shape[i]
                    temp_flops *= 2
                    
                    flops += temp_flops
                
            #--------------------------------------#
            #   全连接层
            #--------------------------------------#
            elif ('Dense' in str(l)):
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                
                    temp_flops = 1
                    for i in range(len(o_shape)):
                        temp_flops *= o_shape[i]
                        
                    if (i_shape[-1] == None):
                        temp_flops = temp_flops * o_shape[-1]
                    else:
                        temp_flops = temp_flops * i_shape[-1]
                    flops += temp_flops

            #--------------------------------------#
            #   普通卷积层
            #--------------------------------------#
            elif ('Conv2D' in str(l) and 'DepthwiseConv2D' not in str(l) and 'SeparableConv2D' not in str(l)):
                strides = l.strides
                ks      = l.kernel_size
                filters = l.filters
                bias    = 1 if l.use_bias else 0
                
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                    
                    if (filters == None):
                        filters = i_shape[2]
                    flops += filters * o_shape[0] * o_shape[1] * (ks[0] * ks[1] * i_shape[2] + bias)

            #--------------------------------------#
            #   逐层卷积层
            #--------------------------------------#
            elif ('Conv2D' in str(l) and 'DepthwiseConv2D' in str(l) and 'SeparableConv2D' not in str(l)):
                strides = l.strides
                ks      = l.kernel_size
                filters = l.filters
                bias    = 1 if l.use_bias else 0
            
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                    
                    if (filters == None):
                        filters = i_shape[2]
                    flops += filters * o_shape[0] * o_shape[1] * (ks[0] * ks[1] + bias)
                
            #--------------------------------------#
            #   深度可分离卷积层
            #--------------------------------------#
            elif ('Conv2D' in str(l) and 'DepthwiseConv2D' not in str(l) and 'SeparableConv2D' in str(l)):
                strides = l.strides
                ks      = l.kernel_size
                filters = l.filters
                
                for i in range(len(l._inbound_nodes)):
                    i_shape = l.get_input_shape_at(i)[1:4]
                    o_shape = l.get_output_shape_at(i)[1:4]
                    
                    if (filters == None):
                        filters = i_shape[2]
                    flops += i_shape[2] * o_shape[0] * o_shape[1] * (ks[0] * ks[1] + bias) + \
                             filters * o_shape[0] * o_shape[1] * (1 * 1 * i_shape[2] + bias)
            #--------------------------------------#
            #   模型中有模型时
            #--------------------------------------#
            elif 'Model' in str(l):
                flops = net_flops(l, print_result=False)
                
            t_flops += flops

            if (table == True):
                print('%25s | %16s | %16s | %16s | %16s | %6s | %5.4f' % (
                    name[:25], str(i_shape), str(o_shape), str(ks), str(filters), str(strides), flops))
                
        except:
            pass
    
    t_flops = t_flops * 2
    if print_result:
        show_flops = t_flops / factor
        print('Total GFLOPs: %.3fG' % (show_flops))
    return t_flops

        voc_annotation.py 内容如下

import os
import random
import xml.etree.ElementTree as ET

import numpy as np

from utils import get_classes

#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode用于指定该文件运行时计算的内容
#   annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
#   annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
#   annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode     = 0
#-------------------------------------------------------------------#
#   必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息
#   与训练和预测所用的classes_path一致即可
#   如果生成的2007_train.txt里面没有目标信息
#   那么就是因为classes没有设定正确
#   仅在annotation_mode为0和2的时候有效
#-------------------------------------------------------------------#
classes_path        = 'model_data/voc_classes.txt'
#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1
#   train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1
#   仅在annotation_mode为0和1的时候有效
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent    = 0.7
train_percent       = 0.9
#-------------------------------------------------------#
#   指向VOC数据集所在的文件夹
#   默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path  = 'VOCdevkit'

VOCdevkit_sets  = [('2007', 'train'), ('2007', 'val')]
classes, _      = get_classes(classes_path)

#-------------------------------------------------------#
#   统计目标数量
#-------------------------------------------------------#
photo_nums  = np.zeros(len(VOCdevkit_sets))
nums        = np.zeros(len(classes))
def convert_annotation(year, image_id, list_file):
    in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
    tree=ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0 
        if obj.find('difficult')!=None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
        
        nums[classes.index(cls)] = nums[classes.index(cls)] + 1
        
if __name__ == "__main__":
    random.seed(0)
    if " " in os.path.abspath(VOCdevkit_path):
        raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")

    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        temp_xml        = os.listdir(xmlfilepath)
        total_xml       = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num     = len(total_xml)  
        list    = range(num)  
        tv      = int(num*trainval_percent)  
        tr      = int(tv*train_percent)  
        trainval= random.sample(list,tv)  
        train   = random.sample(trainval,tr)  
        
        print("train and val size",tv)
        print("train size",tr)
        ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
        ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  
        ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  
        fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  
        
        for i in list:  
            name=total_xml[i][:-4]+'\n'  
            if i in trainval:  
                ftrainval.write(name)  
                if i in train:  
                    ftrain.write(name)  
                else:  
                    fval.write(name)  
            else:  
                ftest.write(name)  
        
        ftrainval.close()  
        ftrain.close()  
        fval.close()  
        ftest.close()
        print("Generate txt in ImageSets done.")

    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        type_index = 0
        for year, image_set in VOCdevkit_sets:
            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
            for image_id in image_ids:
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))

                convert_annotation(year, image_id, list_file)
                list_file.write('\n')
            photo_nums[type_index] = len(image_ids)
            type_index += 1
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")
        
        def printTable(List1, List2):
            for i in range(len(List1[0])):
                print("|", end=' ')
                for j in range(len(List1)):
                    print(List1[j][i].rjust(int(List2[j])), end=' ')
                    print("|", end=' ')
                print()

        str_nums = [str(int(x)) for x in nums]
        tableData = [
            classes, str_nums
        ]
        colWidths = [0]*len(tableData)
        len1 = 0
        for i in range(len(tableData)):
            for j in range(len(tableData[i])):
                if len(tableData[i][j]) > colWidths[i]:
                    colWidths[i] = len(tableData[i][j])
        printTable(tableData, colWidths)

        if photo_nums[0] <= 500:
            print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")

        if np.sum(nums) == 0:
            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("(重要的事情说三遍)。")

        voc_classes.txt 这里面就是你数据集总共要识别多少类就填多少个类,我这里是识别6类。

        这里面也会用到一些模块,xml、numpy、xmltodict等注意提前装好,遇到缺那个模块就装那个模块。运行voc_annotation.py进行数据集划分,会有提示信息,在Main目录下也会生成几个文件。

        然后运行update_xml_attributes.py处理一下标注文件,因为刚开始打标签的路径会保留在标注文件中而现在我们的标注文件和图片的路径是变动了的。如果不改后面训练时会报错。

2. 租用训练服务器

        打开网站AutoDL算力云 | 弹性、好用、省钱。租GPU就上AutoDL进行租用训练机器,如果你的机器够强也可以不租用。推荐3090性价比高,然后地区选择离你所在城市近一点的,这样后面上传数据会快的多。

        选一个有空闲的机器进行租用

        然后选择配置,尽量和我一样吧。

        然后会有登录信息和密码

        打开MobaXterm粘如ssh然后更改端口和用户名

        接着删除多余信息,留下网址

        然后进行登录,粘如网页的密码        登录成功

        输入conda init,然后关闭窗口在重新连接

        然后去TensorFlow仓库吧models下载下来tensorflow/models: Models and examples built with TensorFlow

        接着将model和数据集一起传入服务器的/root/autodl-tmp目录下

        解压出来

3. 训练环境准备

1. 创建conda环境

conda create -n tfod python=3.8 -y

conda activate tfod

2.安装TensorFlow

pip install "tensorflow==2.10.0" 

3. 安装依赖库

sudo apt install -y protobuf-compiler

pip install matplotlib opencv-python pillow lxml Cython contextlib2 tf_slim

4. 编译Protobuf

cd models/research/

protoc object_detection/protos/*.proto --python_out=.

5. 安装OD API

cp object_detection/packages/tf2/setup.py .

python -m pip install .

6.创建标签映射文件 label_map.pbtxt,内容为如下内容,根据实际进行填写

item {

  id: 1

  name: "Blight"}

item {

  id: 2

  name: "Brownspot"}

item {

  id: 3

  name: "Insect"}

item {

  id: 4

  name: "Powderymildew"}

item {

  id: 5

  name: "Rust"}

item {

  id: 6

  name: "Sootymold"}

7.生成TFRecord文件,创建转换文件custom_voc_to_tfrecord.py放到research目录下

代码如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Custom VOC to TFRecord Converter for Agricultural Pest Dataset
"""
import os
import tensorflow as tf
from object_detection.utils import dataset_util
import xml.etree.ElementTree as ET

flags = tf.compat.v1.flags
flags.DEFINE_string('data_dir', '../../VOCdevkit', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', 'train.record', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'label_map.pbtxt', 'Path to label map proto')
flags.DEFINE_string('set', 'train', 'Convert training set or validation set')
FLAGS = flags.FLAGS

CLASS_NAMES = ['Blight', 'Brownspot', 'Insect', 'Powderymildew', 'Rust', 'Sootymold']

def get_class_id(name):
    return CLASS_NAMES.index(name) + 1  # Class IDs start from 1

def dict_to_tf_example(data, dataset_directory):
    img_path = os.path.join(dataset_directory, 'JPEGImages', data['filename'])
    with tf.io.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    
    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmins, ymins, xmaxs, ymaxs = [], [], [], []
    classes_text, classes = [], []

    for obj in data['objects']:
        xmins.append(float(obj['bndbox']['xmin']) / width)
        ymins.append(float(obj['bndbox']['ymin']) / height)
        xmaxs.append(float(obj['bndbox']['xmax']) / width)
        ymaxs.append(float(obj['bndbox']['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(get_class_id(obj['name']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(data['filename'].encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature([0]*len(classes)),
    }))
    return tf_example

def parse_xml(xml_path):
    with tf.io.gfile.GFile(xml_path, 'r') as fid:
        xml_str = fid.read()
    xml = ET.fromstring(xml_str)
    
    data = {
        'filename': xml.find('filename').text,
        'size': {
            'width': xml.find('size/width').text,
            'height': xml.find('size/height').text,
            'depth': xml.find('size/depth').text,
        },
        'objects': []
    }

    for obj in xml.iter('object'):
        obj_data = {
            'name': obj.find('name').text,
            'bndbox': {
                'xmin': obj.find('bndbox/xmin').text,
                'ymin': obj.find('bndbox/ymin').text,
                'xmax': obj.find('bndbox/xmax').text,
                'ymax': obj.find('bndbox/ymax').text,
            }
        }
        data['objects'].append(obj_data)
    
    return data

def main(_):
    writer = tf.io.TFRecordWriter(FLAGS.output_path)
    examples_path = os.path.join(FLAGS.data_dir, FLAGS.year, 'ImageSets', 'Main', FLAGS.set + '.txt')
    
    with open(examples_path, 'r') as f:
        lines = f.readlines()
    
    for idx, line in enumerate(lines):
        xml_path = os.path.join(FLAGS.data_dir, FLAGS.year, 'Annotations', line.strip() + '.xml')
        data = parse_xml(xml_path)
        tf_example = dict_to_tf_example(data, os.path.join(FLAGS.data_dir, FLAGS.year))
        writer.write(tf_example.SerializeToString())
        
        if idx % 100 == 0:
            print(f'Processed {idx}/{len(lines)} examples')

    writer.close()
    print(f'Successfully created TFRecord: {FLAGS.output_path}')

if __name__ == '__main__':
tf.compat.v1.app.run()

8.转换数据集

python custom_voc_to_tfrecord.py \

  --data_dir=../../VOCdevkit \

  --year=VOC2007 \

  --output_path=train.record \

  --label_map_path=label_map.pbtxt \

  --set=train

python custom_voc_to_tfrecord.py \

  --data_dir=../../VOCdevkit  \

  --year=VOC2007 \

  --output_path=val.record \

  --label_map_path=label_map.pbtxt \

  --set=val

看到输出信息代表转换成功

9.下载预训练模型并修改配置文件

wget  http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz

tar -xzvf ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz

修改下载文件中的配置文件 pipeline.config替换如下内容

model {

  ssd {

    num_classes: 6

    image_resizer {

      fixed_shape_resizer {

        height: 320

        width: 320

      }

    }

    feature_extractor {

      type: "ssd_mobilenet_v2_keras"

      depth_multiplier: 1.0

      min_depth: 16

      conv_hyperparams {

        regularizer {

          l2_regularizer {

            weight: 4e-05

          }

        }

        initializer {

          truncated_normal_initializer {

            mean: 0.0

            stddev: 0.03

          }

        }

        activation: RELU_6

        batch_norm {

          decay: 0.97

          center: true

          scale: true

          epsilon: 0.001

          train: true

        }

      }

      override_base_feature_extractor_hyperparams: true

    }

    box_coder {

      faster_rcnn_box_coder {

        y_scale: 10.0

        x_scale: 10.0

        height_scale: 5.0

        width_scale: 5.0

      }

    }

    matcher {

      argmax_matcher {

        matched_threshold: 0.5

        unmatched_threshold: 0.5

        ignore_thresholds: false

        negatives_lower_than_unmatched: true

        force_match_for_each_row: true

        use_matmul_gather: true

      }

    }

    similarity_calculator {

      iou_similarity {

      }

    }

    box_predictor {

      convolutional_box_predictor {

        conv_hyperparams {

          regularizer {

            l2_regularizer {

              weight: 4e-05

            }

          }

          initializer {

            random_normal_initializer {

              mean: 0.0

              stddev: 0.01

            }

          }

          activation: RELU_6

          batch_norm {

            decay: 0.97

            center: true

            scale: true

            epsilon: 0.001

            train: true

          }

        }

        min_depth: 0

        max_depth: 0

        num_layers_before_predictor: 0

        use_dropout: false

        dropout_keep_probability: 0.8

        kernel_size: 1

        box_code_size: 4

        apply_sigmoid_to_scores: false

        class_prediction_bias_init: -4.6

      }

    }

    anchor_generator {

      ssd_anchor_generator {

        num_layers: 6

        min_scale: 0.2

        max_scale: 0.95

        aspect_ratios: 1.0

        aspect_ratios: 2.0

        aspect_ratios: 0.5

        aspect_ratios: 3.0

        aspect_ratios: 0.3333

      }

    }

    post_processing {

      batch_non_max_suppression {

        score_threshold: 1e-08

        iou_threshold: 0.6

        max_detections_per_class: 100

        max_total_detections: 100

        use_static_shapes: false

      }

      score_converter: SIGMOID

    }

    normalize_loss_by_num_matches: true

    loss {

      localization_loss {

        weighted_smooth_l1 {

          delta: 1.0

        }

      }

      classification_loss {

        weighted_sigmoid_focal {

          gamma: 2.0

          alpha: 0.75

        }

      }

      classification_weight: 1.0

      localization_weight: 1.0

    }

    encode_background_as_zeros: true

    normalize_loc_loss_by_codesize: true

    inplace_batchnorm_update: true

    freeze_batchnorm: false

  }

}

train_config {

  batch_size: 48

  data_augmentation_options {

    random_horizontal_flip {}

  }

  data_augmentation_options {

    ssd_random_crop {}

  }

  sync_replicas: true

  optimizer {

    momentum_optimizer {

      learning_rate {

        cosine_decay_learning_rate {

          learning_rate_base: 0.04

          total_steps: 30000

          warmup_learning_rate: 0.01

          warmup_steps: 1500

        }

      }

      momentum_optimizer_value: 0.9

    }

    use_moving_average: false

  }

  fine_tune_checkpoint: "./ssd_mobilenet_v2_320x320_coco17_tpu-8/checkpoint/ckpt-0"

  num_steps: 30000

  startup_delay_steps: 0.0

  replicas_to_aggregate: 8

  max_number_of_boxes: 50

  unpad_groundtruth_tensors: false

  fine_tune_checkpoint_type: "detection"

  fine_tune_checkpoint_version: V2

}

train_input_reader {

  label_map_path: "./label_map.pbtxt"

  tf_record_input_reader {

    input_path: "train.record"

  }

}

eval_config {

  metrics_set: "coco_detection_metrics"

  use_moving_averages: false

  eval_interval_secs: 900

  max_evals: 50

}

eval_input_reader {

  label_map_path: "./label_map.pbtxt"

  shuffle: false

  num_epochs: 1

  tf_record_input_reader {

    input_path: "val.record"

  }

}

        最重要的就是num_classes:要识别的类别, fine_tune_checkpoint_type: "detection" 进行检测如果是classxx就是分类。

4. 模型训练

        装screen包进行后台训,就算你断开连接也会在服务器上进行训练。使用apt-get install screen命令安装。

screen -U -S name #新建窗口 名称为name

screen -S 窗口名称 -X quit

退出该窗口按Ctrl + a + d。

然后使用screen -ls可以查看

自动清理 screen -wipe

screen -r 会话ID

        新建一个训练窗口screen -U -S trian

        然后执行训练,这个过程大概会持续一个小时因为我batch_size设置成48,total_steps: 30000训练步长为30000

python object_detection/model_main_tf2.py \

  --pipeline_config_path=./ssd_mobilenet_v2_320x320_coco17_tpu-8/pipeline.config \

  --model_dir=training/ \

  --alsologtostderr \

  --checkpoint_every_n=1000

        按ctrl+a+d退出训练窗口,新建一个监视窗口screen -U -S jianshi,然后输入tensorboard --port 6007 --logdir ./training

        然后进入实例监控,tensorboard就能进行训练监控。如果没有就刷新一下网页。

5. 导出模型

        训练完后输入以下命令进行导出模型

python object_detection/exporter_main_v2.py \

    --trained_checkpoint_dir=training/ \

    --pipeline_config_path=ssd_mobilenet_v2_320x320_coco17_tpu-8/pipeline.config \

    --output_directory=exported_model

        导出tflite模型

python object_detection/export_tflite_graph_tf2.py \

  --pipeline_config_path=ssd_mobilenet_v2_320x320_coco17_tpu-8/pipeline.config \

  --trained_checkpoint_dir=training/ \

  --output_directory=exported_tflite_model \

  --max_detections=10

6.测试模型

        将导出的模型文件夹下载下来

        编写测试代码test.py,我TensorFlow版本为2.14。

import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import os

# 定义类别标签
LABELS = {
    1: "Blight",           # 疫病
    2: "Brownspot",        # 褐斑病
    3: "Insect",          # 虫害
    4: "Powderymildew",    # 白粉病
    5: "Rust",            # 锈病
    6: "Sootymold"        # 煤烟病
}

# 为不同病害类别定义颜色
COLORS = {
    1: (255, 0, 0),      # 红色 - 疫病
    2: (139, 69, 19),    # 棕色 - 褐斑病
    3: (0, 255, 0),      # 绿色 - 虫害
    4: (255, 255, 255),  # 白色 - 白粉病
    5: (255, 165, 0),    # 橙色 - 锈病
    6: (0, 255, 120)       # 黑色 - 煤烟病
}

def draw_chinese_text(img, text, position, font_size=20, text_color=(0, 0, 0), bg_color=None):
    """使用PIL绘制中文文本"""
    # 转换图像为PIL格式
    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    
    # 尝试加载系统中的中文字体
    try:
        # 尝试常见的中文字体路径
        font_paths = [
            "C:/Windows/Fonts/simsun.ttc",  # 宋体
            "C:/Windows/Fonts/simhei.ttf",   # 黑体
            "C:/Windows/Fonts/msyh.ttc",     # 微软雅黑
        ]
        font = None
        for path in font_paths:
            if os.path.exists(path):
                font = ImageFont.truetype(path, font_size)
                break
        if font is None:
            # 如果找不到中文字体,使用默认字体
            font = ImageFont.load_default()
    except Exception as e:
        print(f"加载字体失败: {e}")
        font = ImageFont.load_default()
    
    # 如果需要绘制背景
    if bg_color is not None:
        # 计算文本大小
        bbox = draw.textbbox(position, text, font=font)
        # 绘制背景矩形
        draw.rectangle([bbox[0]-5, bbox[1]-5, bbox[2]+5, bbox[3]+5], fill=bg_color)
    
    # 绘制文本
    draw.text(position, text, font=font, fill=text_color)
    
    # 转换回OpenCV格式
    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

try:
    # 加载 SavedModel
    saved_model_path = 'exported_model/saved_model'
    model = tf.saved_model.load(saved_model_path)
    infer = model.signatures['serving_default']
    
    # 加载和预处理图像
    image = cv2.imread('imge_17875.jpg')
    if image is None:
        raise FileNotFoundError("无法加载图像文件")
    
    original_image = image.copy()
    image_height, image_width = image.shape[:2]
    
    # 预处理图像
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (320, 320))
    input_tensor = tf.convert_to_tensor(image, dtype=tf.uint8)
    input_tensor = tf.expand_dims(input_tensor, 0)

    # 进行推理
    detections = infer(input_tensor=input_tensor)
    
    # 获取检测结果
    boxes = detections['detection_boxes'].numpy()[0]
    classes = detections['detection_classes'].numpy()[0]
    scores = detections['detection_scores'].numpy()[0]
    num_detections = int(detections['num_detections'].numpy()[0])
    
    # 设置置信度阈值
    confidence_threshold = 0.3

    # 统计每类置信度最高的两个目标
    detections_per_class = {}
    for i in range(num_detections):
        if scores[i] >= confidence_threshold:
            class_id = int(classes[i])
            if class_id not in detections_per_class:
                detections_per_class[class_id] = []
            detections_per_class[class_id].append({
                'score': scores[i],
                'box': boxes[i],
                'index': i
            })
    # 对每类按分数排序,只保留前2个
    for class_id in detections_per_class:
        detections_per_class[class_id].sort(key=lambda x: x['score'], reverse=True)
        detections_per_class[class_id] = detections_per_class[class_id][:2]

    # 汇总所有保留的检测目标
    selected_indices = []
    for dets in detections_per_class.values():
        for det in dets:
            selected_indices.append(det['index'])

    # 处理每个保留的检测结果
    result_image = original_image.copy()
    valid_detections = 0
    print("\n检测到的病害:")
    for i in selected_indices:
        ymin, xmin, ymax, xmax = boxes[i]
        x1 = int(xmin * image_width)
        y1 = int(ymin * image_height)
        x2 = int(xmax * image_width)
        y2 = int(ymax * image_height)
        class_id = int(classes[i])
        class_name = LABELS.get(class_id, f"Unknown_{class_id}")
        confidence = scores[i]
        color = COLORS.get(class_id, (0, 255, 255))
        chinese_names = {
            "Blight": "疫病",
            "Brownspot": "褐斑病",
            "Insect": "虫害",
            "Powderymildew": "白粉病",
            "Rust": "锈病",
            "Sootymold": "煤烟病"
        }
        chinese_name = chinese_names.get(class_name, class_name)
        label = f"{class_name}({chinese_name}): {confidence:.2f}"
        result_image = draw_chinese_text(
            result_image,
            label,
            (x1, max(0, y1 - 25)),
            font_size=20,
            text_color=(0, 0, 0),
            bg_color=color
        )
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
        valid_detections += 1
        print(f"{valid_detections}. {class_name}({chinese_name}): 置信度={confidence:.3f}, 位置={[x1, y1, x2, y2]}")
    
    # 显示结果
    plt.figure(figsize=(15, 8))
    
    # 显示原始图像
    plt.subplot(1, 2, 1)
    plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
    plt.title('Original Image')
    plt.axis('off')
    
    # 显示检测结果
    plt.subplot(1, 2, 2)
    plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
    plt.title(f'Detection Results ({valid_detections} diseases detected)')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # 保存结果图像
    cv2.imwrite('detection_result.jpg', result_image)
    print(f"\nResults saved to: detection_result.jpg")

except Exception as e:
    print(f"\n错误: {str(e)}")
    import traceback
    traceback.print_exc()

        运行结果如下

        然后对模型进行int8量化,编写quantize_to_int8.py转换代码,代码如下

import tensorflow as tf
import numpy as np
import glob
import cv2
import os

# 路径配置
SAVED_MODEL_DIR = './exported_tflite_model/saved_model'
OUTPUT_TFLITE_PATH = 'model_int8.tflite'
CALIBRATION_IMAGES_GLOB = './VOCdevkit/VOC2007/JPEGImages/*.jpg'  # 你的图片路径
NUM_CALIBRATION_IMAGES = 500
INPUT_SIZE = (320, 320)

def representative_dataset():
    image_files = glob.glob(CALIBRATION_IMAGES_GLOB)
    for i, image_path in enumerate(image_files[:NUM_CALIBRATION_IMAGES]):
        img = cv2.imread(image_path)
        if img is None:
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, INPUT_SIZE)
        img = img.astype(np.float32) / 255.0  # 归一化
        img = np.expand_dims(img, axis=0)
        yield [img]

# 创建转换器
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
    tf.lite.OpsSet.TFLITE_BUILTINS
]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# 转换
tflite_model = converter.convert()

# 保存
with open(OUTPUT_TFLITE_PATH, 'wb') as f:
    f.write(tflite_model)

print(f"int8量化模型已保存到: {OUTPUT_TFLITE_PATH}") 

        运行完会生成model_int8.tflite

        然后编写测试tflite的代码test_int8_tflite.py,代码如下

import tensorflow as tf
import numpy as np
import cv2
import sys

MODEL_PATH = 'model_int8.tflite'
IMAGE_PATH = 'imge_17875.jpg'  # 测试图片路径
CLASS_NAMES = ["Blight", "Brownspot", "Insect", "Powderymildew", "Rust", "Sootymold"]
INPUT_SIZE = (320, 320)
THRESHOLD = 0.6

# 预处理图片为uint8
def preprocess_image(image_path, input_size):
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_resized = cv2.resize(img_rgb, input_size)
    img_norm = img_resized / 255.0
    img_uint8 = (img_norm * 255).astype(np.uint8)
    img_uint8 = np.expand_dims(img_uint8, axis=0)
    return img, img_uint8

# 加载模型
def load_interpreter(model_path):
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter

# 推理
def run_inference(interpreter, input_tensor):
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], input_tensor)
    interpreter.invoke()
    outputs = [interpreter.get_tensor(od['index']) for od in output_details]
    return outputs, output_details

# 主流程
def main():
    img, input_tensor = preprocess_image(IMAGE_PATH, INPUT_SIZE)
    interpreter = load_interpreter(MODEL_PATH)
    outputs, output_details = run_inference(interpreter, input_tensor)

    print('模型输出张量数量:', len(outputs))
    for i, (out, od) in enumerate(zip(outputs, output_details)):
        print(f'输出{i}: 名称={od["name"]}, 形状={out.shape}, dtype={out.dtype}')

    # 反量化所有输出
    outputs_float = []
    for i, (out, od) in enumerate(zip(outputs, output_details)):
        scale, zero_point = od['quantization']
        if scale > 0:
            out_float = (out.astype(np.float32) - zero_point) * scale
        else:
            out_float = out.astype(np.float32)
        outputs_float.append(out_float)
        print(f'输出{i}反量化后(前10):', out_float.flatten()[:10])

    # 解析检测结果
    # 假设: outputs_float[0]=scores, [1]=boxes, [2]=num, [3]=classes
    scores = outputs_float[0][0]  # (10,)
    boxes = outputs_float[1][0]   # (10,4)
    num = int(outputs_float[2][0])
    classes = outputs_float[3][0].astype(int)  # (10,)

    print(f"检测到目标数量: {num}")
    img_vis = img.copy()
    h, w = img.shape[:2]
    for i in range(num):
        score = scores[i]
        if score < THRESHOLD:
            continue
        box = boxes[i]  # [ymin, xmin, ymax, xmax] 归一化
        cls = classes[i]
        class_name = CLASS_NAMES[cls] if 0 <= cls < len(CLASS_NAMES) else str(cls)
        ymin, xmin, ymax, xmax = box
        # 坐标还原到原图
        left = int(xmin * w)
        top = int(ymin * h)
        right = int(xmax * w)
        bottom = int(ymax * h)
        cv2.rectangle(img_vis, (left, top), (right, bottom), (0,255,0), 2)
        label = f"{class_name}: {score:.2f}"
        cv2.putText(img_vis, label, (left, top-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,255,0), 2)
        print(f"目标{i}: 类别={class_name}, 分数={score:.2f}, 框={[left, top, right, bottom]}")

    # 保存可视化结果
    cv2.imwrite('result_vis.jpg', cv2.cvtColor(img_vis, cv2.COLOR_RGB2BGR))
    print('检测结果已保存到 result_vis.jpg')

if __name__ == '__main__':
    main() 

        运行结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值