利用目标检测模型标注数据以构建xml文件
代码如下:
# Auto-label every image in ``data_dir`` with an object-detection model and
# write one VOC-style XML annotation file next to each image.


def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to *parent* and return the new element."""
    node = doc.createElement(tag)
    node.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(node)
    return node


def _append_object(doc, parent, label, box):
    """Append one ``<object>`` element (name, pose, flags, bndbox) to *parent*.

    *box* is indexed as (xmin, ymin, xmax, ymax).
    """
    object_node = doc.createElement("object")
    parent.appendChild(object_node)
    _append_text_element(doc, object_node, "name", label)
    _append_text_element(doc, object_node, "pose", "Unspecified")
    _append_text_element(doc, object_node, "truncated", "0")
    _append_text_element(doc, object_node, "difficult", "0")
    bndbox_node = doc.createElement("bndbox")
    object_node.appendChild(bndbox_node)
    # NOTE(review): coordinates are written exactly as the model returns them
    # (str() of each value). Confirm whether downstream consumers of these
    # XML files expect integer pixel values.
    for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), box):
        _append_text_element(doc, bndbox_node, tag, value)
    return object_node


weights = 'pt/last.pt'
det_model = ObjectDet(weights, input_size=(800, 800), conf_thres=0.5)
data_dir = 'datasets/20220302/'
# Class-index -> label lookup used for the <name> of every detected object.
names = ["two_wheel", "three_wheel", "car", "person"]
folder_name = "20220302"
database = "Unknown"

for file_name in os.listdir(data_dir):
    # Annotation files live in the same directory as the images; skip them.
    if file_name.endswith(".xml"):
        continue
    jpg_name = file_name
    print(jpg_name)
    # splitext copes with any extension length (".jpg", ".jpeg", ".png", ...),
    # unlike slicing off a fixed four characters.
    xml_name = os.path.splitext(jpg_name)[0] + ".xml"
    jpg_path = os.path.join(data_dir, jpg_name)
    xml_path = os.path.join(data_dir, xml_name)

    img = cv2.imread(jpg_path)
    if img is None:
        # cv2.imread returns None for files it cannot decode (non-images,
        # corrupt files); skip them instead of crashing on ``img.shape``.
        print("skip (unreadable image):", jpg_name)
        continue
    height, width, channel = img.shape

    # Each result row is unpacked below as (x1, y1, x2, y2, conf, cls).
    res = det_model.det(jpg_path)
    if res is None:
        # The None check must come before touching the result object.
        continue
    # NOTE(review): assumes det() returns a tensor-like object exposing
    # .to()/.numpy() -- confirm against ObjectDet.
    res = res.to("cpu").numpy()

    dom = xml.dom.getDOMImplementation()
    doc = dom.createDocument(None, "annotation", None)
    root_node = doc.documentElement

    _append_text_element(doc, root_node, "folder", folder_name)
    _append_text_element(doc, root_node, "filename", jpg_name)
    _append_text_element(doc, root_node, "path", jpg_path)

    source_node = doc.createElement("source")
    root_node.appendChild(source_node)
    _append_text_element(doc, source_node, "database", database)

    size_node = doc.createElement("size")
    root_node.appendChild(size_node)
    _append_text_element(doc, size_node, "width", width)
    _append_text_element(doc, size_node, "height", height)
    _append_text_element(doc, size_node, "depth", channel)

    _append_text_element(doc, root_node, "segmented", "0")

    # Rows are emitted in reverse order, matching the original script.
    for *xyxy, conf, cls in reversed(res):
        _append_object(doc, root_node, names[int(cls)], xyxy)

    # Open the file as UTF-8 so the bytes on disk agree with the
    # ``encoding="utf-8"`` XML declaration that writexml emits; the context
    # manager closes the handle even if serialization fails.
    with open(xml_path, 'w', encoding='utf-8') as xml_file:
        doc.writexml(xml_file, addindent='\t', newl='\n', encoding='utf-8')