"""check_dataset.py — 检查 YOLO 数据集配置文件(.yaml)中的图像与标签是否一一对应。

使用方法: 将脚本末尾的变量 config_path 改为你的数据集配置文件路径
(例如 "/path/to/your/dataset.yaml"),然后直接运行本脚本。
"""
import os
import yaml
from pathlib import Path
# Image suffixes counted as dataset images. They are compared case-insensitively
# so that e.g. "IMG_001.JPG" is not silently skipped; tif/tiff are included
# because Ultralytics YOLO loads those formats as well.
_IMAGE_SUFFIXES = frozenset({'.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff', '.webp'})


def _as_list(value):
    """Normalize a ``train``/``val``/``test`` entry to a list of path entries.

    The YAML may hold a single path string, a list of paths, or nothing.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _resolve_split_paths(paths, base_path, dataset_type):
    """Prefix each split entry with ``base_path`` and keep the ones that exist.

    Args:
        paths (list): Raw path entries of one split.
        base_path (str): Dataset root taken from the ``path`` key. An empty
            string means the entries are used as given (relative ones then
            resolve against the current working directory). Absolute entries
            ignore ``base_path`` (``os.path.join`` semantics).
        dataset_type (str): Split label used in the printed messages.

    Returns:
        list[str]: The existing paths; missing ones are reported and dropped.
    """
    if not paths:
        print(f"警告: {dataset_type} 路径为空")
        return []
    valid_paths = []
    for path in paths:
        full_path = os.path.join(base_path, path) if base_path else path
        if os.path.exists(full_path):
            valid_paths.append(full_path)
        else:
            print(f"错误: {dataset_type} 路径不存在: {full_path}")
    return valid_paths


def _images_to_labels(path, *, is_file):
    """Swap the last ``images`` directory component of ``path`` for ``labels``.

    This mirrors how Ultralytics YOLO locates labels
    (``.../images/x.jpg`` -> ``.../labels/x.txt``). Only whole path components
    match, so a folder such as ``images_v2`` is left alone. When ``is_file`` is
    true, the final component is a file name and is never swapped. Without a
    matching component the path is returned unchanged, i.e. the labels are
    expected next to the images.

    Returns:
        pathlib.Path: The rewritten path (the suffix is not changed here).
    """
    parts = Path(path).parts
    last_dir = len(parts) - 1 if is_file else len(parts)
    for i in reversed(range(last_dir)):
        if parts[i] == 'images':
            return Path(*parts[:i], 'labels', *parts[i + 1:])
    return Path(path)


def _check_image_dir(path):
    """Pair every image below one split directory with its YOLO label file.

    Prints the per-directory statistics and lists the unmatched files.

    Args:
        path (str): An existing path listed for one split.

    Returns:
        tuple[int, int, int, int]: (image count, label count,
        labels without an image, images without a label).
    """
    print(f"\n- 处理路径: {path}")
    image_dir = Path(path)
    if not image_dir.is_dir():
        # e.g. a *.txt list of image paths: globbing a file finds nothing, which
        # would otherwise be reported as a perfectly matched (empty) split.
        print("  警告: 该路径不是目录(可能是图像列表文件),已跳过逐文件检查")
        return 0, 0, 0, 0

    label_root = _images_to_labels(image_dir, is_file=False)
    if not label_root.is_dir():
        print(f"  警告: 标签目录不存在: {label_root}")
        label_files = set()
    else:
        label_files = set(label_root.rglob('*.txt'))

    image_files = sorted(
        p for p in image_dir.rglob('*')
        if p.suffix.lower() in _IMAGE_SUFFIXES and p.is_file()
    )
    # Map each image to its exact expected label file instead of comparing bare
    # stems, so equally named files in different sub-folders are not merged.
    expected = {
        img: _images_to_labels(img, is_file=True).with_suffix('.txt')
        for img in image_files
    }
    # The label of a deeply nested ".../images/..." image may live outside
    # label_root, hence the is_file() fallback.
    unlabelled = [
        img for img, lbl in expected.items()
        if lbl not in label_files and not lbl.is_file()
    ]
    orphans = sorted(label_files - set(expected.values()))

    print(f"  图像数量: {len(image_files)}")
    print(f"  标签数量: {len(label_files)}")
    if orphans:
        print(f"  警告: 发现 {len(orphans)} 个标签没有对应的图像")
        for lbl in orphans:
            print(f"    - {lbl.relative_to(label_root)}")
    if unlabelled:
        print(f"  错误: 发现 {len(unlabelled)} 个图像没有对应的标签")
        for img in unlabelled:
            print(f"    - {img.relative_to(image_dir)}")
    return len(image_files), len(label_files), len(orphans), len(unlabelled)


def _check_class_config(config):
    """Check that ``nc`` agrees with ``names`` and print the result.

    ``names`` may be a list or, as Ultralytics also accepts, an
    ``{index: name}`` mapping.
    """
    nc = config.get('nc')
    names = config.get('names')
    if nc is None or names is None:
        print("⚠️ 警告: 配置文件中缺少 nc 或 names 字段")
        return
    if isinstance(names, dict):
        names = list(names.values())
    if not isinstance(names, list):
        print("⚠️ 警告: 类别名称配置格式不正确,应为列表或字典")
        return
    if len(names) != nc:
        print(f"⚠️ 警告: 类别数量(nc={nc})与类别名称数量(len(names)={len(names)})不匹配")
    else:
        # str() each entry: YAML loads names such as 0 or yes as non-strings,
        # which would make str.join raise TypeError.
        print(f"\n✓ 类别配置验证通过: nc={nc}, 类别名称: {', '.join(map(str, names))}")


def validate_yolo_data_config(config_path):
    """Validate a YOLO dataset config (.yaml) and check images against labels.

    Resolves the ``train``/``val``/``test`` entries against ``path``, pairs
    every image with its label file, prints per-split statistics and finally
    checks that ``nc`` agrees with ``names``. All findings are printed;
    nothing is returned.

    Args:
        config_path (str): Path to the dataset config (.yaml) file.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        print(f"读取配置文件失败: {e}")
        return
    # An empty file loads as None, and a bare scalar/list has no .get(); report
    # it instead of crashing with AttributeError further down.
    if not isinstance(config, dict):
        print(f"读取配置文件失败: 配置内容应为键值映射,实际为 {type(config).__name__}")
        return

    base_path = config.get('path') or ''

    print("\n=== 验证数据集路径 ===")
    all_datasets = {
        name: _resolve_split_paths(_as_list(config.get(key)), base_path, name)
        for key, name in (('train', "训练集"), ('val', "验证集"), ('test', "测试集"))
    }

    for dataset_name, paths in all_datasets.items():
        if not paths:
            continue
        print(f"\n=== 分析 {dataset_name} ===")
        total_images = total_labels = 0
        missing_images = 0  # label files whose image is missing
        missing_labels = 0  # images whose label file is missing
        for path in paths:
            n_images, n_labels, n_orphans, n_unlabelled = _check_image_dir(path)
            total_images += n_images
            total_labels += n_labels
            missing_images += n_orphans
            missing_labels += n_unlabelled

        print(f"\n--- {dataset_name} 总体统计 ---")
        print(f"  总图像数: {total_images}")
        print(f"  总标签数: {total_labels}")
        print(f"  缺失图像数: {missing_images}")
        print(f"  缺失标签数: {missing_labels}")
        if missing_labels > 0:
            print(f"  ⚠️ 警告: {dataset_name} 存在缺失标签,可能影响训练效果")
        elif missing_images > 0:
            print(f"  ⚠️ 警告: {dataset_name} 存在缺失图像,标签文件可能无效")
        elif total_images == 0:
            # Zero images is not a successful match.
            print(f"  ⚠️ 警告: {dataset_name} 未找到任何图像")
        else:
            print(f"  ✓ {dataset_name} 所有图像和标签匹配正常")

    _check_class_config(config)
if __name__ == "__main__":
    import sys

    # Dataset config to check: the first command-line argument when one is
    # given, otherwise this default path (replace it with your own file).
    config_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/mnt/Virgil/YOLO/drone_dataset/drone_dataset.yaml"
    )
    # Run the validation; all results are printed to stdout.
    validate_yolo_data_config(config_path)
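# NOTE: the lines below are residue copied from the original web page
# (comment counter / folded-comments notice); they are not part of the script.
# 802
#
# 被折叠的 条评论
# 为什么被折叠?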



