数据集
在网上有很多可用的公开的数据集,根据自己的需要,下载相应的数据集,可以用来训练网络,测试网络模型的精度。
[数据集转载来源] 深度学习中的遥感影像数据集
Pascal VOC网址:http://host.robots.ox.ac.uk/pascal/VOC/
转载的一篇包含了比较多的数据集的一篇博文,可以参考一下。
但有些时候,我们需要根据我们自己的需求,根据自己的研究方向和类型,设置自己的数据集,以下,简单的阐述了设置数据集的一些步骤。
创建数据集
在pytorch中,官方文档简单的介绍了创建数据集的简单步骤。
# ================================================================== #
# 5. Input pipeline for custom dataset #
# ================================================================== #
# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# TODO
# 1. Initialize file paths or a list of file names.
# 设置文件和标签的路径,或者文件名list,最关键的就是设置好数据集的路径,以及初始化一些数据集的属性
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
# 通过上述的数据集路径,读取文件,并且对文件进行预处理操作,返回真实的文件数据,比如image and label
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
# 比较简单,只是设置数据集的长度,返回一个值
return 0
# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
batch_size=64,
shuffle=True)
所以说,最关键的就是初始化文件路径和读取文件,以及文件的预处理。
其他的一些需要用到的属性和方法,在需要的时候加上就行。比如如何进行数据读取、如何进行预处理等。
在实际应用中,创建数据集的基本步骤也大致如此,只需要把相应的方法写全即可,下面以目标检测的数据集为例。
栗子
1.数据准备
首先,我们拿到目标检测的遥感图像,放到一个总的文件夹中。再使用标签工具labelImg进行标注,将标注好的xml标签文件同样放到同一个标签文件夹中。(下图仅为部分数据的截图)
这里有个小问题,就是使用不同的标注工具,得到的bonding box的格式会有不同,在后期读取的时候,可能会报错。


以下是图像和标签数据的截图实例:

再创建一个类别文件,设置不同的分类的地物名称,以及一个类别对应的JSON文件,不同类别对应不同的key和value。

将上述文件都放在同一个文件夹中,再将这些数据随机分成训练集和测试集,代码如下。
import os
import random
def train_val_txt(files_path,val_rate,output_train_path,output_val_path):
'''
:param files_path: 保存的所有图片文件的目录
:param val_rate: 选择测试集相对于总体的比率
:param output_train_path: 输出的train的filename的txt目录
:param output_val_path: 输出的val的filename的txt目录
'''
if not os.path.exists(files_path):
print("文件夹不存在")
exit(1)
# 获取文件目录下的所有文件名,返回列表格式
files_name = sorted([file.split('.')[0] for file in os.listdir(files_path)])
files_num = len(files_name)
# 设置采样的序号,从[0,files_num] 中随机抽取k个数
val_index = random.sample(range(0, files_num), k=int(files_num * val_rate))
train_files = []
val_files = []
for index, file_name in enumerate(files_name):
if index in val_index:
val_files.append(file_name)
else:
train_files.append(file_name)
try:
with open(output_train_path,'x') as f:
f.write('\n'.join(train_files))
with open(output_val_path, 'x') as f:
f.write('\n'.join(val_files))
except Exception as e:
print(e)
exit(1)
根据注释,设置路径和分类比,运行后可以得到train.txt和val.txt文件。
文本文件中保存着训练集或测试集的样本名称,在后续操作中,直接读取不同的样本名称,就可以加载不同的数据。
最终效果如下:

这样数据就准备好了。
2.设置数据集
按照官方文档的框架,自定义数据集。
在init中,主要是初始化用户数据集的目录,包括设置标签目录,遥感影像目录,以及预处理。
def __init__(self, data_root, transforms, train=True):
#设置不同的路径,分别设置成图片路径和标签路径
self.root = os.path.join(data_root, "data")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
"""读取训练集/测试集,txt_list是路径"""
if train:
txt_list = os.path.join(self.root, "ImageSets", "Main", "train_1.txt")
else:
txt_list = os.path.join(self.root, "ImageSets", "Main", "val_1.txt")
with open(txt_list) as read:
self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines()]
# 读取分类索引
try:
json_file = open('./data/classes.json', 'r')
self.class_dict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# 定义预处理方式
self.transforms = transforms
len方法主要是返回数据集的个数,即有多少张图像(图像和标签是对应的)。该方法比较简单,直接返回即可。
def __len__(self):
"""返回训练集/测试集中图片的个数"""
return len(self.xml_list)
在getitem中,传入index,即对不同index的图像和标签进行处理,返回一个image和target(包含boxes、label、image_id等信息)。
对于不同的需求,设置不同的方法,这里只是以目标检测为例,故需要返回image、label和boxes边界框等信息。
def __getitem__(self, idx):
# read xml
xml_path = self.xml_list[idx] # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str
with open(xml_path) as fid:
xml_str = fid.read()
# xml = etree.fromstring(xml_str)
xml = etree.fromstring(xml_str.encode('utf-8')) # 读取xml文件的内容
data = self.parse_xml_to_dict(xml)["annotation"]
img_path = os.path.join(self.img_root, data["filename"]) # 从xml文件中得到img文件路径
image = Image.open(img_path)
if image.format != "JPEG":
raise ValueError("Image format not JPEG")
boxes = []
labels = []
iscrowd = [] # 是否难检测,crowd为0表示单目标
for obj in data["object"]:
"""得到训练集边框坐标,分类和难易程度"""
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
iscrowd.append(int(obj["difficult"]))
# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx]) # 当前数据对应

本文介绍如何根据PyTorch官方文档创建自定义遥感影像数据集,涵盖数据准备、数据集类编写及可视化展示等内容。
最低0.47元/天 解锁文章
3232

被折叠的 条评论
为什么被折叠?



