YOLO-tensorflow代码解析一(pascal_voc)

本文主要对YOLO目标检测算法的Tensorflow实现进行解析,聚焦于Pascal_VOC数据集的应用。内容涵盖个人对代码的理解,可能存在误差,欢迎读者指正。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import os
import xml.etree.ElementTree as ET
import numpy as np
import cv2
import pickle
import copy
import yolo.config as cfg


class pascal_voc(object):
    def __init__(self, phase, rebuild=False):
        self.devkil_path = os.path.join(cfg.PASCAL_PATH, 'VOCdevkit')
        self.data_path = os.path.join(self.devkil_path, 'VOC2007')
        self.cache_path = cfg.CACHE_PATH
        self.batch_size = cfg.BATCH_SIZE
        self.image_size = cfg.IMAGE_SIZE
        self.cell_size = cfg.CELL_SIZE
        self.classes = cfg.CLASSES
        self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
        self.flipped = cfg.FLIPPED
        self.phase = phase
        self.rebuild = rebuild
        self.cursor = 0
        self.epoch = 1
        self.gt_labels = None
          #开始为空但是在实例初始化的时候就调用了prepare(),所以每次初始化都会有label
        self.prepare()

    # 批量读取图片和图片的标签信息
    def get(self):
        images = np.zeros(
            (self.batch_size, self.image_size, self.image_size, 3))
        labels = np.zeros(
            (self.batch_size, self.cell_size, self.cell_size, 25))
        count = 0
        while count < self.batch_size:  #读取一个批次的图片信息
            imname = self.gt_labels[self.cursor]['imname']
            flipped = self.gt_labels[self.cursor]['flipped']
            images[count, :, :, :] = self.image_read(imname, flipped)
            labels[count, :, :, :] = self.gt_labels[self.cursor]['label']
            count += 1
            self.cursor += 1
            if self.cursor >= len(self.gt_labels):  #一个批次读取完毕打乱顺序
                np.random.shuffle(self.gt_labels)
                self.cursor = 0
                self.epoch += 1
        return images, labels

    # 对图片数据进行格式转换以及归一化等处理
    def image_read(self, imname, flipped=False):
        image = cv2.imread(imname)
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)#转化为灰度图像0-255
        image = (image / 255.0) * 2.0 - 1.0 #像素强度归一化
        if flipped: #水平翻转
            image = image[:, ::-1, :]
        return image

    # 准备好数据的标签信息
    def prepare(self):
        gt_labels = self.load_labels()
        if self.flipped:
            print('Appending horizontally-flipped training examples ...')#增加水平翻转训练示例
            gt_labels_cp = copy.deepcopy(gt_labels)
            for idx in range(len(gt_labels_cp)): #遍历所有flipped=Ture的图片label
                gt_labels_cp[idx]['flipped'] = True
                gt_labels_cp[idx]['label'] =\
                    gt_labels_cp[idx]['label'][:, ::-1, :] #将列向量倒序,相当于水平翻转
                for i in range(self.cell_size):
                    for j in range(self.cell_size):   #遍历整个cell
                        if gt_labels_cp[idx]['label'][i, j, 0] == 1: #改变包含物体cell box的横坐标,水平反转后相当于size-原来
                            gt_labels_cp[idx]['label'][i, j, 1] = \
                                self.image_size - 1 -\
                                gt_labels_cp[idx]['label'][i, j, 1]
            gt_labels += gt_labels_cp  #将新增的加入list
        np.random.shuffle(gt_labels)
        self.gt_labels = gt_labels #随机打乱标签
        return gt_labels

    #制作标签,为一个列表,列表的每一个元素为一个字典包含一张图片的三个属性:imname、label、flipped
    def load_labels(self):
        cache_file = os.path.join(
            self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')

        if os.path.isfile(cache_file) and not self.rebuild: #如果路径下存在标签文件直接读取,文件为一个列表,其中每个图片为一个字典
            print('Loading gt_labels from: ' + cache_file)  #字典有三个关键字:imname、label、flipped
            with open(cache_file, 'rb') as f:
                gt_labels = pickle.load(f)
            return gt_labels

        print('Processing gt_labels from: ' + self.data_path)

        if not os.path.exists(self.cache_path): #如果路径下不存在文件夹,新建文件夹
            os.makedirs(self.cache_path)

        if self.phase == 'train':
            txtname = os.path.join(
                self.data_path, 'ImageSets', 'Main', 'trainval.txt')    #这个文件中保存着文件的名字
        else:
            txtname = os.path.join(
                self.data_path, 'ImageSets', 'Main', 'test.txt')
        with open(txtname, 'r') as f:
            self.image_index = [x.strip() for x in f.readlines()]

        gt_labels = []
        for index in self.image_index:
            label, num = self.load_pascal_annotation(index) #读取7*7*25的ground truth
            if num == 0:
                continue
            imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg') #带路径的图片名称
            gt_labels.append({'imname': imname,
                              'label': label,
                              'flipped': False})  #将图片名称、label、flipped打包成一个字典
        print('Saving gt_labels to: ' + cache_file)
        with open(cache_file, 'wb') as f:
            pickle.dump(gt_labels, f)
        return gt_labels
    #从annotion文件中读取一个7*7*25的label
    def load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """

        imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
        im = cv2.imread(imname)     #读取图片
        h_ratio = 1.0 * self.image_size / im.shape[0]#计算图片和默认图片大小448的比值
        w_ratio = 1.0 * self.image_size / im.shape[1]
        # im = cv2.resize(im, [self.image_size, self.image_size])

        label = np.zeros((self.cell_size, self.cell_size, 25))
        filename = os.path.join(self.data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')

        for obj in objs:
            bbox = obj.find('bndbox')
            # 对于每一个boundingbox找到在原图中的坐标,由于图像和446大小不一,因此做一个尺度的变换
            x1 = max(min((float(bbox.find('xmin').text) - 1) * w_ratio, self.image_size - 1), 0)
            y1 = max(min((float(bbox.find('ymin').text) - 1) * h_ratio, self.image_size - 1), 0)
            x2 = max(min((float(bbox.find('xmax').text) - 1) * w_ratio, self.image_size - 1), 0)
            y2 = max(min((float(bbox.find('ymax').text) - 1) * h_ratio, self.image_size - 1), 0)
            cls_ind = self.class_to_ind[obj.find('name').text.lower().strip()]
            boxes = [(x2 + x1) / 2.0, (y2 + y1) / 2.0, x2 - x1, y2 - y1]    #转换成中心坐标和长宽比的形式
            #计算中心点在7*7feature map中的位置,强制转化为int型
            x_ind = int(boxes[0] * self.cell_size / self.image_size)
            y_ind = int(boxes[1] * self.cell_size / self.image_size)
            if label[y_ind, x_ind, 0] == 1:
                continue
            label[y_ind, x_ind, 0] = 1  #cell负责box的置信度为1,表示有物体
            label[y_ind, x_ind, 1:5] = boxes    #后四个维度放入坐标,为中心点的形式
            label[y_ind, x_ind, 5 + cls_ind] = 1 #采用onehot模式表示类别

        return label, len(objs)

a = pascal_voc("train")

b = a.get() #我们只要调用get函数就可以得到打包好的label信息了,作者在train.py也是直接调用get


# gt_labels_cp = copy.deepcopy(b)
# gt_labels_cp[0]['label']=gt_labels_cp[0]['label'][:, ::-1, :]
# print(gt_labels_cp[0]['label'].shape)
;

以上均为个人理解,仅供参考,如有错误欢迎纠正,非常感谢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值