【YOLO脚本】数据集yaml文件检查

check_dataset.py

使用方法：将变量 config_path 修改为你的数据集配置文件路径（例如 config_path = "/path/to/your/dataset.yaml"），然后直接运行。

import os
import yaml
from pathlib import Path

# File suffixes (lower-case, dot included) counted as images. They are compared
# case-insensitively so e.g. ``IMG_001.JPG`` is found on case-sensitive file
# systems too. tif/tiff are included because YOLO can load them as well.
_IMG_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'})


def _load_config(config_path):
    """Return the dataset config as a dict, or ``None`` (after printing why) on failure.

    ``config_path`` is either a path to a YAML file or an already-parsed dict,
    which is returned unchanged.
    """
    if isinstance(config_path, dict):
        return config_path
    try:
        # Explicit UTF-8: class names are often non-ASCII and the platform
        # default encoding (e.g. GBK / cp1252 on Windows) would fail on them.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        print(f"读取配置文件失败: {e}")
        return None
    # safe_load returns None for an empty file, and a list/scalar for other
    # top-level shapes; none of those support .get().
    if not isinstance(config, dict):
        print(f"读取配置文件失败: 顶层内容应为字典, 实际为 {type(config).__name__}")
        return None
    return config


def _as_list(value):
    """Normalise a ``train``/``val``/``test`` entry (None, one path, or a list) to a list."""
    if value is None or value == '':
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _resolve_paths(paths, dataset_type, base_path, optional=False):
    """Join each entry with ``base_path`` and keep only the ones that exist.

    Prints an error for every non-existent path. An empty entry list prints a
    warning unless the split is ``optional``.
    """
    if not paths:
        if not optional:
            print(f"警告: {dataset_type} 路径为空")
        return []

    valid_paths = []
    for p in paths:
        # os.path.join returns ``p`` unchanged when it is already absolute.
        full_path = os.path.join(str(base_path), str(p)) if base_path else str(p)
        if os.path.exists(full_path):
            valid_paths.append(full_path)
        else:
            print(f"错误: {dataset_type} 路径不存在: {full_path}")
    return valid_paths


def _swap_last_dir(path, old='images', new='labels'):
    """Replace the LAST path component equal to ``old`` with ``new``.

    Only a whole component matches, so folders such as ``my_images`` are left
    alone, and only the last match is replaced (the YOLO convention).
    Returns ``path`` unchanged if no component equals ``old``.
    """
    parts = list(path.parts)
    for i in reversed(range(len(parts))):
        if parts[i] == old:
            parts[i] = new
            return Path(*parts)
    return path


def _label_for_image(image):
    """Return the label path for ``image``: ``.../images/x.jpg`` -> ``.../labels/x.txt``."""
    return _swap_last_dir(image.parent) / f"{image.stem}.txt"


def _images_in_dir(directory):
    """Return every image file below ``directory`` (recursive), sorted."""
    return sorted(
        p for p in Path(directory).rglob('*')
        if p.suffix.lower() in _IMG_SUFFIXES and p.is_file()
    )


def _read_image_list(list_file):
    """Return the image paths listed in a YOLO ``*.txt`` image-list file.

    Entries starting with ``./`` are resolved against the list file's own
    directory (mirroring the Ultralytics loader); other entries are used as
    written. Blank lines and entries without an image suffix are skipped.
    """
    parent = list_file.parent
    images = []
    with open(list_file, 'r', encoding='utf-8') as f:
        for line in f:
            entry = line.strip()
            if not entry:
                continue
            image = parent / entry[2:] if entry.startswith('./') else Path(entry)
            if image.suffix.lower() in _IMG_SUFFIXES:
                images.append(image)
    return images


def _check_source(path):
    """Check one image source: a directory of images or an image-list ``.txt`` file.

    Prints per-source counts and every unmatched file, and returns
    ``(n_images, n_labels, n_missing_images, n_missing_labels)``.
    """
    source = Path(path)
    absent = []  # images named in a list file that do not exist on disk
    if source.is_file():
        try:
            listed = _read_image_list(source)
        except (OSError, UnicodeDecodeError) as e:
            print(f"  错误: 无法读取图像列表文件: {e}")
            return 0, 0, 0, 0
        image_files = [p for p in listed if p.is_file()]
        absent = [p for p in listed if not p.is_file()]
        label_dir = None
    else:
        image_files = _images_in_dir(source)
        label_dir = _swap_last_dir(source)
        if not label_dir.is_dir():
            print(f"  警告: 标签目录不存在: {label_dir}")

    # Map each image to the exact label file YOLO would read for it; matching
    # on full paths avoids collisions between same-named files in subfolders.
    expected = {img: _label_for_image(img) for img in image_files}
    if label_dir is not None:
        label_files = set(label_dir.rglob('*.txt')) if label_dir.is_dir() else set()
    else:
        # A list file has no single label directory: count the labels that
        # exist for the listed images.
        label_files = {lbl for lbl in expected.values() if lbl.is_file()}

    # The is_file() fallback covers labels that live outside label_dir
    # (e.g. when an extra ``images`` folder is nested below the split root).
    missing_lbl = [img for img, lbl in expected.items()
                   if lbl not in label_files and not lbl.is_file()]
    orphan_lbl = sorted(label_files - set(expected.values()))

    print(f"  图像数量: {len(image_files)}")
    print(f"  标签数量: {len(label_files)}")

    if absent:
        print(f"  错误: 发现 {len(absent)} 个列表中的图像文件不存在")
        for p in absent:
            print(f"    - {p}")

    if orphan_lbl:
        print(f"  警告: 发现 {len(orphan_lbl)} 个标签没有对应的图像")
        for p in orphan_lbl:
            print(f"    - {p}")

    if missing_lbl:
        # NOTE: Ultralytics trains on unlabeled images as background samples,
        # so this may be intentional — confirm for your dataset.
        print(f"  错误: 发现 {len(missing_lbl)} 个图像没有对应的标签")
        for p in missing_lbl:
            print(f"    - {p}")

    return len(image_files), len(label_files), len(absent) + len(orphan_lbl), len(missing_lbl)


def _check_classes(config):
    """Print whether ``nc`` and ``names`` are present and consistent.

    ``names`` may be a list or an ``{index: name}`` dict (the newer Ultralytics
    format); dict keys must be the integers ``0..n-1``. A missing ``nc`` is
    accepted and inferred from ``names``.
    """
    nc = config.get('nc')
    names = config.get('names')
    if names is None:
        if nc is None:
            print("\n⚠️ 警告: 配置文件中缺少 nc 或 names 字段")
        else:
            print(f"\n⚠️ 警告: 配置文件中缺少 names 字段 (nc={nc})")
        return

    if isinstance(names, dict):
        if set(names) != set(range(len(names))):
            print("\n⚠️ 警告: names 字典的键应为从 0 开始的连续整数类别索引")
            return
        name_list = [str(names[i]) for i in range(len(names))]
    elif isinstance(names, (list, tuple)):
        name_list = [str(n) for n in names]
    else:
        print("\n⚠️ 警告: 类别名称配置格式不正确,应为列表或字典")
        return

    if nc is None:
        print(f"\n✓ 类别配置验证通过: nc 未设置 (按 names 推断为 {len(name_list)}), "
              f"类别名称: {', '.join(name_list)}")
    elif len(name_list) != nc:
        print(f"\n⚠️ 警告: 类别数量(nc={nc})与类别名称数量(len(names)={len(name_list)})不匹配")
    else:
        print(f"\n✓ 类别配置验证通过: nc={nc}, 类别名称: {', '.join(name_list)}")


def validate_yolo_data_config(config_path):
    """Validate a YOLO dataset config and check that images and labels match.

    Loads the config and resolves the ``train``/``val``/``test`` entries
    against ``path``. Each entry may be a directory or an image-list ``.txt``
    file. For each split, it prints image/label counts and every file that
    has no counterpart. Label locations follow the YOLO convention: the last
    ``images`` directory component of an image path becomes ``labels`` and
    the suffix becomes ``.txt``. Finally, it checks ``nc`` against ``names``.

    Args:
        config_path (str | os.PathLike | dict): Path of the dataset ``.yaml``
            file, or an already-parsed config dict.

    Returns:
        dict | None: ``None`` if the config could not be loaded. Otherwise, a
        mapping from split name (``"训练集"``, ``"验证集"``, ``"测试集"``) to a
        dict with the keys ``images``, ``labels``, ``missing_images`` and
        ``missing_labels``. Only splits with at least one existing path are
        included.
    """
    config = _load_config(config_path)
    if config is None:
        return None

    base_path = config.get('path') or ''
    splits = (
        ("训练集", 'train', False),
        ("验证集", 'val', False),
        ("测试集", 'test', True),  # the test split is optional
    )

    print("\n=== 验证数据集路径 ===")
    resolved = {}
    for dataset_name, key, optional in splits:
        resolved[dataset_name] = _resolve_paths(
            _as_list(config.get(key)), dataset_name, base_path, optional)

    summary = {}
    for dataset_name, paths in resolved.items():
        if not paths:
            continue

        print(f"\n=== 分析 {dataset_name} ===")
        total_images = total_labels = missing_images = missing_labels = 0
        for path in paths:
            print(f"\n- 处理路径: {path}")
            n_img, n_lbl, n_miss_img, n_miss_lbl = _check_source(path)
            total_images += n_img
            total_labels += n_lbl
            missing_images += n_miss_img
            missing_labels += n_miss_lbl

        print(f"\n--- {dataset_name} 总体统计 ---")
        print(f"  总图像数: {total_images}")
        print(f"  总标签数: {total_labels}")
        print(f"  缺失图像数: {missing_images}")
        print(f"  缺失标签数: {missing_labels}")

        if missing_labels > 0:
            print(f"  ⚠️ 警告: {dataset_name} 存在缺失标签,可能影响训练效果")
        elif missing_images > 0:
            print(f"  ⚠️ 警告: {dataset_name} 存在缺失图像,标签文件可能无效")
        else:
            print(f"  ✓ {dataset_name} 所有图像和标签匹配正常")

        summary[dataset_name] = {
            "images": total_images,
            "labels": total_labels,
            "missing_images": missing_images,
            "missing_labels": missing_labels,
        }

    _check_classes(config)
    return summary

if __name__ == "__main__":
    import sys

    # Default config path: edit it, or pass a path as the first command-line
    # argument (python check_dataset.py /path/to/dataset.yaml) to override it.
    config_path = "/mnt/Virgil/YOLO/drone_dataset/drone_dataset.yaml"
    if len(sys.argv) > 1:
        config_path = sys.argv[1]

    # Run the validation; results are printed to stdout.
    validate_yolo_data_config(config_path)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值