YOLO训练脚本汇总

一、环境配置与检查

CUDA

检查CUDA版本:
nvcc -V

查看CUDA安装位置
where nvcc

安装地址:https://developer.nvidia.com/cuda-toolkit-archive

cuDNN

https://developer.nvidia.com/rdp/cudnn-archive

cuDNN解压后有下面三个文件夹
在这里插入代码片在这里插入图片描述

  • 把 bin 目录的内容拷贝到:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin

  • 把 include 目录的内容拷贝到:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\include

  • 把 lib\x64 目录的内容拷贝到:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\lib\x64

conda

# 查看版本
conda --version
pip --version

# 查看,新建,删除环境
conda env list
conda create --name envname python=x.x
conda env remove --name envname

# 查看包
pip list
conda list
pip install package_name==x.x
pip uninstall package_name

conda activate envname


换清华园

# 添加清华镜像源
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
# 设置搜索时显示通道地址
conda config --set show_channel_urls yes

# 检查是否成功
conda config --show channels

# 换回默认园
conda config --remove-key channels  

check env

import sys
import platform
import subprocess
import importlib
import torch
import numpy as np


def check_python_version():
    """检查Python版本"""
    print("=== Python 环境 ===")
    print(f"Python 版本: {
     
     sys.version.split()[0]}")
    required = (3, 8)
    current = sys.version_info[:2]
    if current < required:
        print(f"警告: 推荐使用 Python {
     
     required[0]}.{
     
     required[1]} 或更高版本")
    print()


def check_os_info():
    """检查操作系统信息"""
    print("=== 操作系统信息 ===")
    print(f"系统: {
     
     platform.system()} {
     
     platform.release()}")
    print(f"架构: {
     
     platform.machine()}")
    print(f"处理器: {
     
     platform.processor()}")
    print()


def check_gpu():
    """检查GPU和CUDA环境"""
    print("=== GPU 环境 ===")
    if torch.cuda.is_available():
        print(f"CUDA 可用: 是")
        print(f"CUDA 版本: {
     
     torch.version.cuda}")
        print(f"GPU 数量: {
     
     torch.cuda.device_count()}")

        for i in range(torch.cuda.device_count()):
            print(f"GPU {
     
     i}: {
     
     torch.cuda.get_device_name(i)}")
            print(f"  显存: {
     
     torch.cuda.get_device_properties(i).total_memory / 1024 ** 3:.2f} GB")

        # 检查cuDNN
        print(f"cuDNN 启用: {
     
     '是' if torch.backends.cudnn.enabled else '否'}")
        if torch.backends.cudnn.enabled:
            print(f"cuDNN 版本: {
     
     torch.backends.cudnn.version()}")
    else:
        print("CUDA 可用: 否")
        print("警告: 未检测到可用的NVIDIA GPU和CUDA环境,训练将非常缓慢")
    print()


def check_required_libraries():
    """检查所需的Python库"""
    print("=== 依赖库检查 ===")
    libraries = {
   
   
        'torch': 'PyTorch',
        'torchvision': 'TorchVision',
        'numpy': 'NumPy',
        'opencv-python': 'OpenCV',
        'matplotlib': 'Matplotlib',
        'PIL': 'Pillow',
        'PyYAML': 'PyYAML',
        'tqdm': 'tqdm',
        'tensorboard': 'TensorBoard'
    }

    for lib, name in libraries.items():
        try:
            module = importlib.import_module(lib)
            version = getattr(module, '__version__', '未知版本')
            print(f"{
     
     name}: 已安装 (版本: {
     
     version})")
        except ImportError:
            print(f"警告: {
     
     name} 未安装")
    print()


def check_disk_space():
    """检查磁盘空间(仅Linux/macOS)"""
    print("=== 磁盘空间检查 ===")
    if platform.system() in ['Linux', 'Darwin']:
        try:
            result = subprocess.check_output(['df', '-h', '.']).decode().splitlines()
            if len(result) >= 2:
                print(f"当前目录所在磁盘: {
     
     result[1]}")
        except Exception as e:
            print(f"检查磁盘空间时出错: {
     
     e}")
    else:
        print("磁盘空间检查仅支持Linux和macOS系统")
    print()


def check_memory():
    """检查系统内存(简化版)"""
    print("=== 内存检查 ===")
    if platform.system() == 'Linux':
        try:
            with open('/proc/meminfo', 'r') as f:
                mem_total = f.readline()
                mem_available = f.readline()
            print(f"总内存: {
     
     mem_total.strip()}")
            print(f"可用内存: {
     
     mem_available.strip()}")
        except Exception as e:
            print(f"检查内存时出错: {
     
     e}")
    else:
        print("详细内存信息检查仅支持Linux系统")
    print()


def main():
    print("=" * 50)
    print("         YOLO 训练环境检查工具         ")
    print("=" * 50)
    print()

    check_python_version()
    check_os_info()
    check_gpu()
    check_memory()
    check_disk_space()
    check_required_libraries()

    print("=" * 50)
    print("环境检查完成")
    print("注意: 警告信息表示可能存在潜在问题,但不代表完全无法运行")


if __name__ == "__main__":
    main()


二、数据集检查与预处理

yolo格式数据集可视化检测脚本

  • 变量记录两个路径,图片路径和标签路径。
  • 随机抽取三个图片,并将标签的框可视化出来,可视化图片上方有对应的图片名。
    效果:
    在这里插入图片描述
import os
import random
import cv2
import matplotlib.pyplot as plt

# 设置中文字体,确保中文正常显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]

def visualize_yolo_dataset(image_dir, label_dir, num_samples=3):
    """
    可视化YOLO格式的目标检测数据集
    
    参数:
    image_dir (str): 图片文件夹路径
    label_dir (str): 标签文件夹路径
    num_samples (int): 随机抽取的样本数量
    """
    # 获取所有图片文件
    image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
    
    # 如果没有图片,直接返回
    if not image_files:
        print("未找到图片文件")
        return
    
    # 随机选择指定数量的图片
    if len(image_files) < num_samples:
        num_samples = len(image_files)
    selected_images = random.sample(image_files, num_samples)
    
    # 创建子图
    fig, axes = plt.subplots(1, num_samples, figsize=(5 * num_samples, 5))
    if num_samples == 1:
        axes = [axes]
    
    # 遍历选中的图片
    for i, img_file in enumerate(selected_images):
        # 读取图片
        img_path = os.path.join(image_dir, img_file)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # 转换颜色空间
        
        # 获取对应的标签文件
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(label_dir, label_file)
        
        # 读取标签
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                lines = f.readlines()
            
            # 获取图像尺寸
            h, w, _ = img.shape
            
            # 绘制边界框
            for line in lines:
                parts = line.strip().split()
                if len(parts) < 5:  # 确保标签格式正确
                    continue
                
       
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值