概要
- 路径:
- OWOD/datasets/coco_utils/balanced_ft.py
- OWOD/detectron2/utils/store_non_list.py
- 作用:从各任务的训练列表文件(如t1_train.txt、t2_train.txt;只做task 1时为train.txt)中随机抽取类别均衡的样本集(balanced set of exemplars),存储为tn_ft_xx.txt文件(n为任务编号,xx为每一类的样本数量items_per_class)
整体架构流程
- 列出数据集中的类的名称 (不同task)
- 设置路径、参数(items_per_class)
- 创建Store实例image_store 作为采样结果
- 解析xml文件(annotation文件)将每一类选够items_per_class个样本
- 采样结果按类别保存在Store内部,所以需要调用retrieve(-1)取出每一类的样本列表,再合并成一个完整的文件名列表
- 删除重复项
- 用map函数给每个文件名末尾加上换行符('\n'),得到可逐行写入的序列
- 写入文件
代码解析
import itertools
import random
import os
import xml.etree.ElementTree as ET
from fvcore.common.file_io import PathManager
from detectron2.utils.store_non_list import Store
# Class-name tables used to build the list of known classes for each task.

# COCO-style spellings of the six VOC classes whose names differ between the
# two naming schemes. Index-aligned with BASE_VOC_CLASS_NAMES below.
VOC_CLASS_NAMES_COCOFIED = [
    "airplane",
    "dining table",
    "motorcycle",
    "potted plant",
    "couch",
    "tv",
]

# VOC spellings of the same six classes, in the same order as above.
BASE_VOC_CLASS_NAMES = [
    "aeroplane",
    "diningtable",
    "motorbike",
    "pottedplant",
    "sofa",
    "tvmonitor",
]

# The 20 VOC classes.
VOC_CLASS_NAMES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# The 20 classes of the T2 group.
T2_CLASS_NAMES = [
    "truck", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "elephant", "bear",
    "zebra", "giraffe", "backpack", "umbrella",
    "handbag", "tie", "suitcase", "microwave",
    "oven", "toaster", "sink", "refrigerator",
]

# The 20 classes of the T3 group.
T3_CLASS_NAMES = [
    "frisbee", "skis", "snowboard", "sports ball",
    "kite", "baseball bat", "baseball glove", "skateboard",
    "surfboard", "tennis racket", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake",
]

# The 20 classes of the T4 group.
T4_CLASS_NAMES = [
    "bed", "toilet", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "book",
    "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl",
]

# Single-entry list holding the "unknown" label.
UNK_CLASS = ["unknown"]
# Per-task configuration: change this section accordingly for each task t*.

# Known classes for this run: the VOC classes followed by the T2 classes.
known_classes = [*VOC_CLASS_NAMES, *T2_CLASS_NAMES]

# List files with one image id per line; all of them are pooled for sampling.
train_files = [
    '/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/t2_train.txt',
    '/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/t1_train.txt',
]

# Alternative configuration using only the VOC classes:
# known_classes = list(itertools.chain(VOC_CLASS_NAMES))
# train_files = ['/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/train.txt']

# Directory holding one "<image id>.xml" annotation file per image.
annotation_location = '/home/fk1/workspace/OWOD/datasets/VOC2007/Annotations'

# Number of samples to collect for each class.
items_per_class = 20

# Output list file; the per-class sample count is part of the file name.
# Previously used output location:
# dest_file = '/home/fk1/workspace/OWOD/datasets/VOC2007/ImageSets/Main/t2_ft_' + str(items_per_class) + '.txt'
dest_file = 'OWOD-master/output/20230703/datasets/VOC2007/ImageSets/Main/t2_ft_{}.txt'.format(items_per_class)
# Pool the lines of every training list (trailing newlines are kept here and
# stripped later, when each id is used), then shuffle the pooled list in place
# so that images are visited in random order.
file_names = []
for list_path in train_files:
    with open(list_path, mode="r") as list_file:
        file_names += list_file.readlines()
random.shuffle(file_names)
# Exemplar store with one slot per known class; items_per_class is passed as
# the per-class size.
# NOTE(review): presumably Store caps each class at items_per_class entries --
# confirm in detectron2/utils/store_non_list.py.
image_store = Store(len(known_classes), items_per_class)
current_min_item_count = 0

# Build the lookups once instead of scanning the lists for every object.
cocofied_to_voc = dict(zip(VOC_CLASS_NAMES_COCOFIED, BASE_VOC_CLASS_NAMES))
class_to_index = {}
for known_index, known_name in enumerate(known_classes):
    # Keep the first index for a name, matching list.index().
    class_to_index.setdefault(known_name, known_index)

for fileid in file_names:
    fileid = fileid.strip()
    if not fileid:
        # A blank line in a list file has no annotation ("<dir>/.xml") to open.
        continue
    anno_file = os.path.join(annotation_location, fileid + ".xml")
    with PathManager.open(anno_file) as f:
        tree = ET.parse(f)

    # Register the image at most once per class. Adding it once per object
    # would let an image with many objects of one class fill that class's
    # slot with repeats of the same file id.
    classes_in_image = set()
    for obj in tree.findall("object"):
        cls = obj.find("name").text
        # Map a COCO-style spelling to its VOC spelling when there is one.
        cls = cocofied_to_voc.get(cls, cls)
        class_index = class_to_index.get(cls)
        if class_index is None or class_index in classes_in_image:
            continue
        classes_in_image.add(class_index)
        image_store.add((fileid,), (class_index,))

    # Size of the least-populated class after this image.
    current_min_item_count = min(len(items) for items in image_store.retrieve(-1))
    print(current_min_item_count)
    # Stop as soon as every class holds the requested number of items.
    # ">=" also terminates if a class could ever hold more than requested.
    if current_min_item_count >= items_per_class:
        break
else:
    # Every listed image was consumed without filling all classes, so the
    # resulting set is not balanced; say so instead of finishing silently.
    print('Warning: training lists exhausted; smallest class has '
          + str(current_min_item_count) + ' of ' + str(items_per_class) + ' items')
# Flatten the per-class lists returned by the store into one list of image ids.
filtered_file_names = []
for items in image_store.retrieve(-1):
    filtered_file_names.extend(items)
print(image_store)
# Total entries across all classes; an image that contains several known
# classes is listed under each of them, so this can exceed the unique count.
print(len(filtered_file_names))

# Drop duplicate ids, then sort: set iteration order is arbitrary, and sorting
# makes the written file reproducible for a given sample.
unique_file_names = set(filtered_file_names)
print(len(unique_file_names))
filtered_file_names = [name + '\n' for name in sorted(unique_file_names)]

# Create the output directory if needed, so the sampling work is not lost to
# a FileNotFoundError when the destination folder does not exist yet.
dest_dir = os.path.dirname(dest_file)
if dest_dir:
    os.makedirs(dest_dir, exist_ok=True)
with open(dest_file, mode="w") as myFile:
    myFile.writelines(filtered_file_names)
print('Saved to file: ' + dest_file)
小结
关键方法:
- 使用Class Store
- 使用set函数去除重复项
- 使用map函数