【OWOD】从之前任务采样balanced set标本的代码分析

概要

  1. 路径:
    • OWOD/datasets/coco_utils/balanced_ft.py
    • OWOD/detectron2/utils/store_non_list.py
  2. 作用:从train.txt数据文件中随机抽取balanced set of exemplars,存储为tn_ft_xx.txt文件(xx为每一类的样本数量)

整体架构流程

  1. 列出数据集中的类的名称 (不同task)
  2. 设置路径、参数(items_per_class)
  3. 创建Store实例image_store 作为采样结果
  4. 解析xml文件(annotation文件)将每一类选够items_per_class个样本
    • 因为结果是map类定义的str形式,所以需要用retrieve获取完整列表格式并存储
  5. 删除重复项
  6. 用map函数映射成带回车的序列
  7. 写入文件

代码解析

import itertools
import random
import os
import xml.etree.ElementTree as ET
from fvcore.common.file_io import PathManager

from detectron2.utils.store_non_list import Store

VOC_CLASS_NAMES_COCOFIED = [
    "airplane",  "dining table", "motorcycle",
    "potted plant", "couch", "tv"
]

BASE_VOC_CLASS_NAMES = [
    "aeroplane", "diningtable", "motorbike",
    "pottedplant",  "sofa", "tvmonitor"
]

VOC_CLASS_NAMES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

T2_CLASS_NAMES = [
    "truck", "traffic light", "fire hydrant", "stop sign", "parking meter",
    "bench", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    "microwave", "oven", "toaster", "sink", "refrigerator"
]

T3_CLASS_NAMES = [
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake"
]

T4_CLASS_NAMES = [
    "bed", "toilet", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl"
]

UNK_CLASS = ["unknown"]

# Change this accodingly for each task t*
known_classes = list(itertools.chain(VOC_CLASS_NAMES, T2_CLASS_NAMES))  # 拼接task1 和 task2 的classes
train_files = ['/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/t2_train.txt','/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/t1_train.txt']  # training数据路径

# known_classes = list(itertools.chain(VOC_CLASS_NAMES))
# train_files = ['/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/train.txt']
annotation_location = '/home/fk1/workspace/OWOD/datasets/VOC2007/Annotations'  # 标签路径

items_per_class = 20  # 每个class取20个样本
#dest_file = '/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/t2_ft_' + str(items_per_class) + '.txt'  # 输出路径
dest_file = 'OWOD-master/output/20230703/datasets/VOC2007/ImageSets/Main/t2_ft_' + str(items_per_class) + '.txt'  # 输出路径

file_names = []
for tf in train_files:
    with open(tf, mode="r") as myFile:
        file_names.extend(myFile.readlines())  # 将train数据的每一行拼接为列表

random.shuffle(file_names)  # 打乱顺序

image_store = Store(len(known_classes), items_per_class)  # 随机抽样 每类items_per_class个

current_min_item_count = 0

for fileid in file_names:  # 对于每个train文件
    fileid = fileid.strip()  # 删除头尾空字符
    anno_file = os.path.join(annotation_location, fileid + ".xml")  # anno文件路径

    with PathManager.open(anno_file) as f:  # 打开XML文件
        tree = ET.parse(f)  # 解析XML文件

    for obj in tree.findall("object"):
        cls = obj.find("name").text
        if cls in VOC_CLASS_NAMES_COCOFIED:
            cls = BASE_VOC_CLASS_NAMES[VOC_CLASS_NAMES_COCOFIED.index(cls)]  # 如果是VOC_CLASS_NAMES_COCOFIED中的类则用其中的label替换
        if cls in known_classes:
            image_store.add((fileid,), (known_classes.index(cls),))  # 根据索引和fileid添加该样本的内容

    current_min_item_count = min([len(items) for items in image_store.retrieve(-1)])
    print(current_min_item_count)
    if current_min_item_count == items_per_class:  # 一直采样直到最少的样本数达到需要的20个
        break

filtered_file_names = []
for items in image_store.retrieve(-1):  # 获取并存储完整列表
    filtered_file_names.extend(items)

print(image_store)
print(len(filtered_file_names))
print(len(set(filtered_file_names)))  # 删除重复项

filtered_file_names = set(filtered_file_names)
filtered_file_names = map(lambda x: x + '\n', filtered_file_names)  # 映射成序列

with open(dest_file, mode="w") as myFile:
    myFile.writelines(filtered_file_names)  # 写入文件

print('Saved to file: ' + dest_file)

小结

关键方法:

  • 使用Class Store
  • 使用set函数去除重复项
  • 使用map函数
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值