上一节讲的cls的创建dataset还是使用的处理复制文件从而直接导入的傻瓜方法,该类方法在后期数据集特别大时会造成内存的重复读取耗时问题。且目标检测领域dataset类是避不开的,所以针对沐神的代码进行dataset从0开始搭建!
目录
1.目标检测数据集--banana
1.1描述
给的文件中,有train与val两个文件,这两个文件里面都是右边这张图所示:


其中label里面如下图所示:第一列写着images里面的图片名称,后五列分别为图片包含目标的种类,该目标左上角x,y坐标与右下角xy坐标:

1.2读取文件:
代码如下图所示:
data_dir = '/CV/xhr/banana-detection'
def read_data_bananas(is_train=True):
"""读取⾹蕉检测数据集中的图像和标签"""
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')))
# 这⾥的target包含(类别,左上⻆x,左上⻆y,右下⻆x,右下⻆y),
# 其中所有图像都具有相同的⾹蕉类(索引为0)
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
关于里面的set_index('image_name'),其是配套后面的for loop里面的.iterrows使用的,set_index表示讲img_name这一列的元素替换掉原始的索引列,如下图所示:使用该命令前:

使用该命令后:可以看到原先索引标号0.1.2...替换成了img_name列中的各个元素。

为什么这么处理?
因为要有for loop里面的.iterrows()操作,该操作返回的两个值分别为每一行的索引与每一行剩下的元素:其中这个剩下的多个元素是pandas的Series数据结构,可以用list换成列表。

在注意里面这个torchvision.io.read_image(path)操作是将path中的图片(支持JPEG,PNG,BMP,GIF,TIFF等)从磁盘上读取并转换成tensor张量,输出为(C,H,W)
1.3 读取函数最终返回值
让我们看一下这个函数最终返回的是啥:

这里让a等于后面这一长串,返回的images是一个有1000个元素的list(一共有1000张图片),其中的每一个元素为一个tensor的图片(C,H,W),另一个值a是一个tensor(1000, m, 5),这里的m表示一张图设有几个目标物体,这里每张图都只有一个,所以m设为1.
1.4 构造dataset类
注意,这里面的__getitem__()最后索引的features[idx]与labels[idx]就是1.3那一张图里面对应的东西,即返回的是每一张idx对应的图片的(c,h,w).float与该图片对应的标注(m, 5)。
class BananasDataset(torch.utils.data.Dataset):
"""⼀个⽤于加载⾹蕉检测数据集的⾃定义数据集"""
def __init__(self, is_train):
self.features, self.labe

文章详细介绍了如何从零开始构建目标检测数据集,包括读取CSV文件,处理图像和标签,构建自定义Dataset类,以及使用Dataloader加载数据。此外,还展示了对Kaggle树叶分类数据集的处理,包括将类别标签转换为数字,构建数据加载器,以及用于训练的网络模型。整个过程涵盖了数据预处理、数据加载和网络构建的关键步骤。
最低0.47元/天 解锁文章
2586





