数据预处理和数据集的设置——以目标检测数据集为例

本文介绍如何根据PyTorch官方文档创建自定义遥感影像数据集,涵盖数据准备、数据集类编写及可视化展示等内容。

数据集

在网上有很多可用的公开的数据集,根据自己的需要,下载相应的数据集,可以用来训练网络,测试网络模型的精度。

[数据集转载来源] 深度学习中的遥感影像数据集

Pascal VOC网址http://host.robots.ox.ac.uk/pascal/VOC/

转载的一篇包含了比较多的数据集的一篇博文,可以参考一下。

但有些时候,我们需要根据我们自己的需求,根据自己的研究方向和类型,设置自己的数据集,以下,简单的阐述了设置数据集的一些步骤。

创建数据集

在pytorch中,官方文档简单的介绍了创建数据集的简单步骤。

# ================================================================== #
#                5. Input pipeline for custom dataset                 #
# ================================================================== #

# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        # 设置文件和标签的路径,或者文件名list,最关键的就是设置好数据集的路径,以及初始化一些数据集的属性
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # 通过上述的数据集路径,读取文件,并且对文件进行预处理操作,返回真实的文件数据,比如image and label
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        # 比较简单,只是设置数据集的长度,返回一个值
        return 0 

# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

所以说,最关键的就是初始化文件路径和读取文件,以及文件的预处理。

其他的一些需要用到的属性和方法,在需要的时候加上就行。比如如何进行数据读取、如何进行预处理等。

在实际应用中,创建数据集的基本步骤也大致如此,只需要把相应的方法写全即可,下面以目标检测的数据集为例。

栗子

1.数据准备

首先,我们拿到目标检测的遥感图像,放到一个总的文件夹中。再使用标签工具labelImg进行标注,将标注好的xml标签文件同样放到同一个标签文件夹中。(下图仅为部分数据的截图)

这里有个小问题,就是使用不同的标注工具,得到的bonding box的格式会有不同,在后期读取的时候,可能会报错。

在这里插入图片描述
在这里插入图片描述

以下是图像和标签数据的截图实例:

在这里插入图片描述 在这里插入图片描述

再创建一个类别文件,设置不同的分类的地物名称,以及一个类别对应的JSON文件,不同类别对应不同的key和value。

在这里插入图片描述 在这里插入图片描述
将上述文件都放在同一个文件夹中,再将这些数据随机分成训练集和测试集,代码如下。

import os
import random


def train_val_txt(files_path,val_rate,output_train_path,output_val_path):
    '''
    :param files_path: 保存的所有图片文件的目录
    :param val_rate: 选择测试集相对于总体的比率
    :param output_train_path: 输出的train的filename的txt目录
    :param output_val_path: 输出的val的filename的txt目录
    '''

    if not os.path.exists(files_path):
        print("文件夹不存在")
        exit(1)

    # 获取文件目录下的所有文件名,返回列表格式
    files_name = sorted([file.split('.')[0] for file in os.listdir(files_path)])

    files_num = len(files_name)

    # 设置采样的序号,从[0,files_num] 中随机抽取k个数
    val_index = random.sample(range(0, files_num), k=int(files_num * val_rate))
    train_files = []
    val_files = []
    for index, file_name in enumerate(files_name):
        if index in val_index:
            val_files.append(file_name)
        else:
            train_files.append(file_name)

    try:
        with open(output_train_path,'x') as f:
            f.write('\n'.join(train_files))
        with open(output_val_path, 'x') as f:
            f.write('\n'.join(val_files))
    except Exception as e:
        print(e)
        exit(1)

根据注释,设置路径和分类比,运行后可以得到train.txt和val.txt文件

文本文件中保存着训练集或测试集的样本名称,在后续操作中,直接读取不同的样本名称,就可以加载不同的数据。

最终效果如下:

在这里插入图片描述

这样数据就准备好了。

2.设置数据集

按照官方文档的框架,自定义数据集。

在init中,主要是初始化用户数据集的目录,包括设置标签目录,遥感影像目录,以及预处理。

def __init__(self, data_root, transforms, train=True):
    #设置不同的路径,分别设置成图片路径和标签路径
    self.root = os.path.join(data_root, "data")
    self.img_root = os.path.join(self.root, "JPEGImages")
    self.annotations_root = os.path.join(self.root, "Annotations")

    """读取训练集/测试集,txt_list是路径"""
    if train:
        txt_list = os.path.join(self.root, "ImageSets", "Main", "train_1.txt")
    else:
        txt_list = os.path.join(self.root, "ImageSets", "Main", "val_1.txt")

    with open(txt_list) as read:
        self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                         for line in read.readlines()]

    # 读取分类索引
    try:
        json_file = open('./data/classes.json', 'r')
        self.class_dict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)

    # 定义预处理方式
    self.transforms = transforms 

len方法主要是返回数据集的个数,即有多少张图像(图像和标签是对应的)。该方法比较简单,直接返回即可。

def __len__(self):
    """返回训练集/测试集中图片的个数"""
    return len(self.xml_list)

在getitem中,传入index,即对不同index的图像和标签进行处理,返回一个image和target(包含boxes、label、image_id等信息)。

对于不同的需求,设置不同的方法,这里只是以目标检测为例,故需要返回image、label和boxes边界框等信息。

def __getitem__(self, idx):
    # read xml
    xml_path = self.xml_list[idx]  # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str
    with open(xml_path) as fid:
        xml_str = fid.read()
    # xml = etree.fromstring(xml_str)
    xml = etree.fromstring(xml_str.encode('utf-8'))  # 读取xml文件的内容
    data = self.parse_xml_to_dict(xml)["annotation"]
    img_path = os.path.join(self.img_root, data["filename"])  # 从xml文件中得到img文件路径
    image = Image.open(img_path)
    if image.format != "JPEG":
        raise ValueError("Image format not JPEG")
    boxes = []
    labels = []
    iscrowd = []  # 是否难检测,crowd为0表示单目标
    for obj in data["object"]:
        """得到训练集边框坐标,分类和难易程度"""
        xmin = float(obj["bndbox"]["xmin"])
        xmax = float(obj["bndbox"]["xmax"])
        ymin = float(obj["bndbox"]["ymin"])
        ymax = float(obj["bndbox"]["ymax"])
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(self.class_dict[obj["name"]])
        iscrowd.append(int(obj["difficult"]))

    # convert everything into a torch.Tensor
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])  # 当前数据对应
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值