一、环境配置与检查
CUDA
检查CUDA版本:
nvcc -V
查看CUDA安装位置
where nvcc
安装地址:https://developer.nvidia.com/cuda-toolkit-archive
cuDNN
https://developer.nvidia.com/rdp/cudnn-archive
cuDNN解压后有下面三个文件夹
在这里插入代码片
-
把 bin 目录的内容拷贝到:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin -
把 include 目录的内容拷贝到:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\include -
把 lib\x64 目录的内容拷贝到:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\lib\x64
conda
# 查看版本
conda --version
pip --version
# 查看,新建,删除环境
conda env list
conda create --name envname python=x.x
conda env remove --name envname
# 查看包
pip list
conda list
pip install package_name==x.x
pip uninstall package_name
conda activate envname
换清华园
# 添加清华镜像源
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
# 设置搜索时显示通道地址
conda config --set show_channel_urls yes
# 检查是否成功
conda config --show channels
# 换回默认园
conda config --remove-key channels
check env
import sys
import platform
import subprocess
import importlib
import torch
import numpy as np
def check_python_version():
"""检查Python版本"""
print("=== Python 环境 ===")
print(f"Python 版本: {
sys.version.split()[0]}")
required = (3, 8)
current = sys.version_info[:2]
if current < required:
print(f"警告: 推荐使用 Python {
required[0]}.{
required[1]} 或更高版本")
print()
def check_os_info():
"""检查操作系统信息"""
print("=== 操作系统信息 ===")
print(f"系统: {
platform.system()} {
platform.release()}")
print(f"架构: {
platform.machine()}")
print(f"处理器: {
platform.processor()}")
print()
def check_gpu():
"""检查GPU和CUDA环境"""
print("=== GPU 环境 ===")
if torch.cuda.is_available():
print(f"CUDA 可用: 是")
print(f"CUDA 版本: {
torch.version.cuda}")
print(f"GPU 数量: {
torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f"GPU {
i}: {
torch.cuda.get_device_name(i)}")
print(f" 显存: {
torch.cuda.get_device_properties(i).total_memory / 1024 ** 3:.2f} GB")
# 检查cuDNN
print(f"cuDNN 启用: {
'是' if torch.backends.cudnn.enabled else '否'}")
if torch.backends.cudnn.enabled:
print(f"cuDNN 版本: {
torch.backends.cudnn.version()}")
else:
print("CUDA 可用: 否")
print("警告: 未检测到可用的NVIDIA GPU和CUDA环境,训练将非常缓慢")
print()
def check_required_libraries():
"""检查所需的Python库"""
print("=== 依赖库检查 ===")
libraries = {
'torch': 'PyTorch',
'torchvision': 'TorchVision',
'numpy': 'NumPy',
'opencv-python': 'OpenCV',
'matplotlib': 'Matplotlib',
'PIL': 'Pillow',
'PyYAML': 'PyYAML',
'tqdm': 'tqdm',
'tensorboard': 'TensorBoard'
}
for lib, name in libraries.items():
try:
module = importlib.import_module(lib)
version = getattr(module, '__version__', '未知版本')
print(f"{
name}: 已安装 (版本: {
version})")
except ImportError:
print(f"警告: {
name} 未安装")
print()
def check_disk_space():
"""检查磁盘空间(仅Linux/macOS)"""
print("=== 磁盘空间检查 ===")
if platform.system() in ['Linux', 'Darwin']:
try:
result = subprocess.check_output(['df', '-h', '.']).decode().splitlines()
if len(result) >= 2:
print(f"当前目录所在磁盘: {
result[1]}")
except Exception as e:
print(f"检查磁盘空间时出错: {
e}")
else:
print("磁盘空间检查仅支持Linux和macOS系统")
print()
def check_memory():
"""检查系统内存(简化版)"""
print("=== 内存检查 ===")
if platform.system() == 'Linux':
try:
with open('/proc/meminfo', 'r') as f:
mem_total = f.readline()
mem_available = f.readline()
print(f"总内存: {
mem_total.strip()}")
print(f"可用内存: {
mem_available.strip()}")
except Exception as e:
print(f"检查内存时出错: {
e}")
else:
print("详细内存信息检查仅支持Linux系统")
print()
def main():
print("=" * 50)
print(" YOLO 训练环境检查工具 ")
print("=" * 50)
print()
check_python_version()
check_os_info()
check_gpu()
check_memory()
check_disk_space()
check_required_libraries()
print("=" * 50)
print("环境检查完成")
print("注意: 警告信息表示可能存在潜在问题,但不代表完全无法运行")
if __name__ == "__main__":
main()
二、数据集检查与预处理
yolo格式数据集可视化检测脚本
- 变量记录两个路径,图片路径和标签路径。
- 随机抽取三个图片,并将标签的框可视化出来,可视化图片上方有对应的图片名。
效果:

import os
import random
import cv2
import matplotlib.pyplot as plt
# 设置中文字体,确保中文正常显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
def visualize_yolo_dataset(image_dir, label_dir, num_samples=3):
"""
可视化YOLO格式的目标检测数据集
参数:
image_dir (str): 图片文件夹路径
label_dir (str): 标签文件夹路径
num_samples (int): 随机抽取的样本数量
"""
# 获取所有图片文件
image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
# 如果没有图片,直接返回
if not image_files:
print("未找到图片文件")
return
# 随机选择指定数量的图片
if len(image_files) < num_samples:
num_samples = len(image_files)
selected_images = random.sample(image_files, num_samples)
# 创建子图
fig, axes = plt.subplots(1, num_samples, figsize=(5 * num_samples, 5))
if num_samples == 1:
axes = [axes]
# 遍历选中的图片
for i, img_file in enumerate(selected_images):
# 读取图片
img_path = os.path.join(image_dir, img_file)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转换颜色空间
# 获取对应的标签文件
label_file = os.path.splitext(img_file)[0] + '.txt'
label_path = os.path.join(label_dir, label_file)
# 读取标签
if os.path.exists(label_path):
with open(label_path, 'r') as f:
lines = f.readlines()
# 获取图像尺寸
h, w, _ = img.shape
# 绘制边界框
for line in lines:
parts = line.strip().split()
if len(parts) < 5: # 确保标签格式正确
continue

最低0.47元/天 解锁文章
5万+

被折叠的 条评论
为什么被折叠?



