# Compare annotations against their label images for consistency (renders one visualization per image).
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import cv2
# ====== Input paths (set according to the output path) ======
dataset_dir = './data/test1'
# Every location used by the script lives directly under the dataset root.
image_dir, label_dir, class_file, output_vis_dir = (
    os.path.join(dataset_dir, leaf)
    for leaf in ('images', 'annotations', 'class_names.txt', 'vis_check')
)
os.makedirs(output_vis_dir, exist_ok=True)

# ====== Load the class mapping ======
# One class name per line; the list index is used as the class id further down.
with open(class_file, 'r', encoding='utf-8') as f:
    class_names = [line.strip() for line in f]
# ====== Build the color map ======
def get_color_map_list(num_classes):
    """Build a table of distinct RGB colors, one row per class.

    Colors are generated for labels ``0 .. num_classes`` by spreading the
    bits of each label over the high bits of the R, G and B channels.  The
    entry for label 0 (always black) is dropped before returning, so row
    ``i`` of the result holds the color generated for label ``i + 1``.

    Args:
        num_classes: Number of rows wanted in the returned table.

    Returns:
        A ``uint8`` array of shape ``(num_classes, 3)`` with RGB triplets.
    """
    total = num_classes + 1
    flat = [0, 0, 0] * total
    for label in range(total):
        base = 3 * label
        remaining = label
        shift = 7  # start at the most significant bit of each channel
        while remaining:
            # Consume three label bits per round: one each for R, G and B.
            flat[base] |= (remaining & 1) << shift
            flat[base + 1] |= ((remaining >> 1) & 1) << shift
            flat[base + 2] |= ((remaining >> 2) & 1) << shift
            shift -= 1
            remaining >>= 3
    # Skip the first triplet (label 0) and present the rest as rows of RGB.
    return np.array(flat[3:]).reshape((-1, 3)).astype(np.uint8)
# 256 rows so that any 8-bit label value can be used directly as a row index.
palette = get_color_map_list(256)

# The font does not depend on the image being processed, so resolve it once.
try:
    font = ImageFont.truetype("arial.ttf", 14)
except OSError:
    # truetype() raises OSError when the font file cannot be located or read;
    # fall back to Pillow's built-in bitmap font in that case.
    font = ImageFont.load_default()


def _class_name(class_id):
    """Return the display name for class_id, tolerating ids missing from class_names."""
    if 0 <= class_id < len(class_names):
        return class_names[class_id]
    # Label values outside the class list would otherwise raise IndexError
    # and abort the whole run.
    return f"class_{class_id}"


# ====== Visualize and check every image ======
# Sorted so that the processing order (and the log) is reproducible.
for fname in sorted(os.listdir(image_dir)):
    name, _ = os.path.splitext(fname)
    image_path = os.path.join(image_dir, fname)
    label_path = os.path.join(label_dir, name + '.png')
    if not os.path.exists(label_path):
        print(f"⚠️ 缺失标注图:{label_path}")
        continue

    # Load the image and its label mask.  convert() returns fully loaded
    # copies, so the underlying files can be closed right away.
    try:
        with Image.open(image_path) as raw_img, Image.open(label_path) as raw_label:
            img = raw_img.convert("RGB")
            label = raw_label.convert("P")
    except OSError as err:
        # Covers non-image files, directories and corrupt images.
        print(f"⚠️ 无法读取图像或标注:{image_path} ({err})")
        continue

    # Image.blend() requires both inputs to have the same size; a mismatch is
    # exactly the kind of inconsistency this script should report.
    if img.size != label.size:
        print(f"⚠️ 图像与标注尺寸不一致:{fname} 图像{img.size} 标注{label.size}")
        continue

    label_np = np.array(label)
    label_rgb = palette[label_np]

    # Blend the colorized mask over the image (80% mask, 20% image).
    blended = Image.blend(img, Image.fromarray(label_rgb), alpha=0.8)
    draw = ImageDraw.Draw(blended)

    # Write each class name at the centroid of that class's pixels.
    used_class_ids = []
    for class_id in np.unique(label_np):
        class_id = int(class_id)
        if class_id == 0:
            continue  # skip background
        mask = (label_np == class_id).astype(np.uint8)
        moments = cv2.moments(mask)
        if moments['m00'] == 0:
            continue
        cx = int(moments['m10'] / moments['m00'])
        cy = int(moments['m01'] / moments['m00'])
        draw.text((cx, cy), _class_name(class_id), fill='white', font=font)
        used_class_ids.append(class_id)

    # Canvas holding the blended image on the left and the legend on the right.
    legend_height = 20 * len(used_class_ids)
    legend_width = 200
    combined = Image.new(
        'RGB',
        (blended.width + legend_width, max(blended.height, legend_height)),
        (255, 255, 255),
    )
    combined.paste(blended, (0, 0))

    # Legend: one color swatch plus class name per labelled class.  Skipped
    # when nothing was labelled, so no zero-height image is created.
    if used_class_ids:
        legend = Image.new('RGB', (legend_width, legend_height), color=(255, 255, 255))
        legend_draw = ImageDraw.Draw(legend)
        for idx, class_id in enumerate(used_class_ids):
            # Plain ints rather than numpy scalars for Pillow's color argument.
            color = tuple(int(channel) for channel in palette[class_id])
            y = idx * 20
            legend_draw.rectangle([0, y, 20, y + 15], fill=color)
            legend_draw.text((25, y), _class_name(class_id), fill='black', font=font)
        combined.paste(legend, (blended.width, 0))

    # Save the visualization.
    out_path = os.path.join(output_vis_dir, name + '_vis.png')
    combined.save(out_path)
    print(f"✅ 可视化完成: {out_path}")

print("\n🎉 检查与可视化全部完成!输出目录:", output_vis_dir)