Tensorflow2.0 YOLO篇之提取xml文件信息

本文深入解析了使用TensorFlow 2.0进行YOLO算法的实现过程,从数据集介绍到XML文件解析,再到图像预处理及模型训练。详细介绍了如何处理甜菜和杂草的数据集,包括提取图片中的物体位置和类别信息,并将其转化为可用于模型训练的格式。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Tensorflow2.0 YOLO篇之提取xml文件信息



数据集介绍

数据集下载地址:

链接:https://pan.baidu.com/s/1ZP9H2ym3Vp4Sda1mNiv9Pw 
提取码:5okb 
复制这段内容后打开百度网盘手机App,操作更方便哦

这次选择的数据集是甜菜(sugarbeet)和杂草(weed)的数据集
在这里插入图片描述
在数据集的xml文件中包含了图片中物体的位置形状(x,y,w,h)和label
其中的一个xml文件

<annotation>
	<folder>train</folder>
	<filename>X2-10-1.png</filename>
	<path /><source>
		<database>Unknown</database>
	</source>
	<size>
		<width>512</width>
		<height>512</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>weed</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>71</xmin>
			<ymin>265</ymin>
			<xmax>115</xmax>
			<ymax>278</ymax>
		</bndbox>
	</object>
	
	......

	<object>
		<name>sugarbeet</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>322</xmin>
			<ymin>266</ymin>
			<xmax>363</xmax>
			<ymax>294</ymax>
		</bndbox>
	</object>
</annotation>

现在我们要做的工作就是将这些数据储存到numpy数组中去,代码中我尽可能的写了注释,书写这个的是否选择了vscode作为编译工具,以为vscode对于jupyter的支持较好,可以在编写的过程中更加方便的查看每一步的运行结果

同时在编写这一步的时候需要注意的一个点就是每个图片中的物体个数可能不一样,这样我们的boxes的个数就有问题。因为每个图片中的框信息都没有超过5个(上图除外那是我自己画的),所以我们每一张图片都涉及有五个空,不足的就用0来填充

#%%
import tensorflow as tf
import os,glob
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
from tensorflow import keras
# set random seed
tf.random.set_seed(2233)
np.random.seed(2233)

# %%
print(tf.__version__)
print(tf.text.is_gpu_available())



# %%
import xml.etree.ElementTree as ET

def parse_annotation(img_dir,ann_dir,labels):
    # parse annotation and save is into numpy array
    # img_dir: image path
    # ann_dir: annotation xml file path
    # labels: ('sugarweet','weed')
    imgs_info =[]
    # for each annotation xml file
    max_boxes = 0
    for ann in os.listdir(ann_dir):
        tree = ET.parse(os.path.join(ann_dir,ann))

        img_info = dict()
        img_info['object'] = []
        boxes_counter = 0
        for elem in tree.iter():
            if 'filename' in elem.tag:
                img_info['filename'] = os.path.join(img_dir,elem.text)
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                assert img_info['width'] == 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                assert img_info['width'] == 512
            if 'object' in elem.tag or 'part' in elem.tag:
                # x1-y1-x2-y2-label
                object_info =  [0,0,0,0,0]
                boxes_counter += 1
                for attr in list(elem):
                    # add image info into object_info
                    if 'name' in attr.tag:
                        label = labels.index(attr.text) + 1
                        object_info[4] = label
                    if 'bndbox' in attr.tag:
                        for pos in list(attr):
                            if 'xmin' in pos.tag:
                                object_info[0] = int(pos.text)
                            if 'ymin' in pos.tag:
                                object_info[1] = int(pos.text)
                            if 'xmax' in pos.tag:
                                object_info[2] = int(pos.text)
                            if 'ymax' in pos.tag:
                                object_info[3] = int(pos.text)
                img_info['object'].append(object_info)
        imgs_info.append(img_info) # filename,w/h/box_info
        # (N,5) = (max_objects_num,5) 5 is x-y-w-h-label
        if boxes_counter > max_boxes:
            max_boxes = boxes_counter
    # the maximum boxes number is max_boxes
    # [b,max_things,5]
    boxes = np.zeros([len(imgs_info),max_boxes,5])
    imgs = [] # filename last
    for i,img_info in enumerate(imgs_info):
        # [N,5] N: boxes number
        img_boxes = np.array(img_info['object'])
        # overwrite the N boxes info
        boxes[i,:img_boxes.shape[0]] = img_boxes
        imgs.append(img_info['filename'])
        print(img_info['filename'],boxes[i,:5])
    # imgs: list of image path
    # boxes:[b,40,5]
    return imgs,boxes

# %%
obj_names = ('sugarbeet','weed')
imgs,boxes = parse_annotation('data/train/image','data/train/annotation',obj_names)

参考书籍: TensorFlow 深度学习 — 龙龙老师

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值