关于os.makedirs(path, exist_ok=True)

文章讨论了在处理文件或目录路径时的规则,即如果路径不存在则创建,但不会覆盖已存在的路径,保持原有内容不变。

如果path不存在,会被创建;

如果path已经存在的话,不会覆盖既存的path,什么都不会做。

import cv2
import os
import json
import shutil
import random
import numpy as np
import torch  # imported up front: train/validate functions query torch.cuda
from tqdm import tqdm
from pathlib import Path
from ultralytics import YOLO
from sklearn.model_selection import train_test_split

# ================== Configuration ==================
VIDEO_PATH = r"E:\专业综合实践-光电镊\参考资料\参考资料\41378_2025_892_MOESM3_ESM.mp4"  # input video file
OUTPUT_DIR = r"E:\专业综合实践-光电镊\参考资料\参考资料\41378_2025_892_MOESM3_ESM.dataset"  # dataset output directory
BBOX_SIZE = 5        # side length (pixels) of the box drawn around each black dot
FRAME_INTERVAL = 10  # keep one frame every FRAME_INTERVAL frames
TRAIN_SPLIT = 0.8    # fraction of images assigned to the training set

# Create the dataset directory layout; exist_ok=True makes this idempotent.
for _subdir in ("images", "labels", "train/images", "train/labels",
                "val/images", "val/labels"):
    os.makedirs(os.path.join(OUTPUT_DIR, _subdir), exist_ok=True)


# ================== 1. Frame extraction ==================
def extract_frames(video_path, output_dir, frame_interval=10):
    """Extract every `frame_interval`-th frame of the video as a JPEG.

    Frames are written to ``<output_dir>/images/frame_NNNNNN.jpg`` and the
    list of written file paths is returned in order.

    Raises:
        ValueError: if the video file cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"无法打开视频文件: {video_path}")

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f"视频信息: {frame_count}帧, {fps:.1f}FPS, {width}x{height}分辨率")

    frame_paths = []
    frame_index = 0
    try:
        for i in tqdm(range(frame_count), desc="提取帧"):
            ret, frame = cap.read()
            if not ret:
                break  # stream ended early (reported frame count can overshoot)
            if i % frame_interval == 0:
                frame_path = os.path.join(output_dir, "images",
                                          f"frame_{frame_index:06d}.jpg")
                cv2.imwrite(frame_path, frame)
                frame_paths.append(frame_path)
                frame_index += 1
    finally:
        # Release the capture handle even if writing a frame raised.
        cap.release()

    print(f"成功提取 {frame_index} 帧图像")
    return frame_paths
# ================== 2. Annotation tool ==================
def annotate_frames(frame_paths):
    """改进的交互式标注工具"""
    annotations = {}
    window_name = "标注工具 [Z:撤销 | C:清除 | B:上一帧 | S:保存 | Q:退出]"

    # Navigation state shared with the mouse callback via closure.
    frame_pos = 0
    n_frames = len(frame_paths)
    dirty = True  # True whenever the displayed image must be re-rendered

    def on_mouse(event, x, y, flags, param):
        # A left click records one annotation point on the current frame.
        nonlocal frame_pos, dirty
        current = frame_paths[frame_pos]
        pts = annotations.setdefault(current, [])
        if event == cv2.EVENT_LBUTTONDOWN:
            pts.append((x, y))
            dirty = True

    # Single window instance reused for every frame.
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.setMouseCallback(window_name, on_mouse)

    while 0 <= frame_pos < n_frames:
        current = frame_paths[frame_pos]
        pts = annotations.setdefault(current, [])

        # Re-render only when something changed (click, key, frame switch).
        if dirty:
            img = cv2.imread(current)
            if img is None:
                print(f"警告: 无法读取图像 {current}")
                frame_pos += 1
                continue
            canvas = img.copy()
            for pt in pts:
                cv2.circle(canvas, pt, 3, (0, 0, 255), -1)
            status_text = f"Frame {frame_pos + 1}/{n_frames} - 标注点: {len(pts)}"
            cv2.putText(canvas, status_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow(window_name, canvas)
            dirty = False

        key = cv2.waitKey(10)  # short wait keeps the UI responsive
        if key == -1:
            continue  # no key pressed this tick
        key &= 0xFF
        if key == ord('z') and pts:          # undo last point
            pts.pop()
            dirty = True
        elif key == ord('c'):                # clear all points on this frame
            annotations[current] = []
            dirty = True
        elif key == ord('b'):                # previous frame
            frame_pos = max(0, frame_pos - 1)
            dirty = True
        elif key in (ord('n'), ord('s')):    # next frame / save-and-advance
            frame_pos = min(n_frames - 1, frame_pos + 1)
            dirty = True
        elif key == ord('q'):                # quit the annotation loop
            break

    cv2.destroyAllWindows()
    return annotations
# ================== 3. Convert to YOLO format ==================
def save_yolo_annotations(annotations, output_dir, bbox_size):
    """Save point annotations as YOLO-format label files.

    Each clicked point becomes a square box of side `bbox_size` pixels
    (clipped at the image borders), class id 0 ("black_dot"), written as
    normalized ``class cx cy w h`` lines into ``<output_dir>/labels``.
    """
    for img_path, points in annotations.items():
        img = cv2.imread(img_path)
        if img is None:
            # The annotated file is unreadable; skip it instead of crashing.
            print(f"警告: 无法读取图像 {img_path}")
            continue
        height, width = img.shape[:2]

        img_name = os.path.basename(img_path)
        label_name = os.path.splitext(img_name)[0] + ".txt"
        label_path = os.path.join(output_dir, "labels", label_name)

        # Float half-size so the box is exactly bbox_size wide; the old
        # integer `bbox_size // 2` on both sides lost one pixel for odd sizes.
        half = bbox_size / 2.0
        with open(label_path, 'w') as f:
            for (x, y) in points:
                x1 = max(0.0, x - half)
                y1 = max(0.0, y - half)
                x2 = min(float(width), x + half)
                y2 = min(float(height), y + half)
                # YOLO format: class_id center_x center_y width height (normalized).
                center_x = (x1 + x2) / 2 / width
                center_y = (y1 + y2) / 2 / height
                box_width = (x2 - x1) / width
                box_height = (y2 - y1) / height
                f.write(f"0 {center_x:.6f} {center_y:.6f} {box_width:.6f} {box_height:.6f}\n")


# ================== 4. Dataset split ==================
def split_dataset(output_dir, train_split=0.8):
    """Split images/labels into train/ and val/ subsets and write data.yaml.

    Returns the path of the generated YOLO dataset configuration file.
    """
    all_images = [f for f in os.listdir(os.path.join(output_dir, "images"))
                  if f.endswith('.jpg')]
    train_images, val_images = train_test_split(
        all_images, train_size=train_split, random_state=42)
    print(f"数据集划分: {len(train_images)}训练, {len(val_images)}验证")

    def copy_files(file_list, subset):
        # Copy each image, plus its label file when one exists (frames with
        # zero annotations legitimately have no label file).
        for img_file in file_list:
            src_img = os.path.join(output_dir, "images", img_file)
            dst_img = os.path.join(output_dir, subset, "images", img_file)
            shutil.copy(src_img, dst_img)
            label_file = os.path.splitext(img_file)[0] + ".txt"
            src_label = os.path.join(output_dir, "labels", label_file)
            dst_label = os.path.join(output_dir, subset, "labels", label_file)
            if os.path.exists(src_label):
                shutil.copy(src_label, dst_label)

    copy_files(train_images, "train")
    copy_files(val_images, "val")

    # YOLO dataset configuration consumed by ultralytics.
    data_yaml = f"""
path: {os.path.abspath(output_dir)}
train: train/images
val: val/images

# 类别数量
nc: 1

# 类别名称
names: ['black_dot']
"""
    yaml_path = os.path.join(output_dir, "data.yaml")
    with open(yaml_path, 'w') as f:
        f.write(data_yaml)
    return yaml_path


# ================== 5. Train YOLOv8 ==================
def train_yolov8(yaml_path, model_name="yolov8n.pt", epochs=100):
    """Train a YOLOv8 detector on the prepared dataset.

    Returns the path of the best checkpoint produced by the run.
    """
    import torch  # local import so this function works regardless of import order

    model = YOLO(model_name)
    train_args = {
        'data': yaml_path,
        'epochs': epochs,
        'imgsz': 640,
        'batch': 16,
        'name': 'black_dot_detector',
        'patience': 20,  # early-stopping patience (epochs without improvement)
        'device': 0 if torch.cuda.is_available() else 'cpu',
    }
    results = model.train(**train_args)

    best_model_path = results.save_dir / 'weights' / 'best.pt'
    print(f"训练完成! 最佳模型保存于: {best_model_path}")
    return best_model_path


# ================== 6. Model validation ==================
def validate_model(model_path, data_yaml):
    """Validate the trained model and print the main detection metrics."""
    import torch  # local import so this function works regardless of import order

    model = YOLO(model_path)
    metrics = model.val(
        data=data_yaml,
        batch=16,
        imgsz=640,
        conf=0.25,
        iou=0.45,
        device=0 if torch.cuda.is_available() else 'cpu',
    )

    print(f"验证结果:")
    print(f"mAP@0.5: {metrics.box.map50:.4f}")
    print(f"mAP@0.5:0.95: {metrics.box.map:.4f}")
    # ultralytics DetMetrics exposes mean precision/recall as box.mp / box.mr;
    # the attributes `precision`/`recall` used previously do not exist.
    print(f"精确率: {metrics.box.mp:.4f}")
    print(f"召回率: {metrics.box.mr:.4f}")
    return metrics


# ================== Main pipeline ==================
if __name__ == "__main__":
    # Step 1: extract frames from the video.
    frame_paths = extract_frames(VIDEO_PATH, OUTPUT_DIR, FRAME_INTERVAL)

    # Step 2: annotate the extracted frames interactively.
    print("\n开始标注...")
    annotations = annotate_frames(frame_paths)
    print(f"标注完成! 共标注 {len(annotations)} 帧图像")

    # Step 3: convert annotations to YOLO label files.
    save_yolo_annotations(annotations, OUTPUT_DIR, BBOX_SIZE)
    print("YOLO格式标注已保存")

    # Step 4: split into train/val and write data.yaml.
    yaml_path = split_dataset(OUTPUT_DIR, TRAIN_SPLIT)
    print(f"数据集配置文件已创建: {yaml_path}")

    # Step 5: train the YOLOv8 model.
    print("\n开始训练YOLOv8模型...")
    import torch  # kept here so the device report works even without the top import
    device = "GPU" if torch.cuda.is_available() else "CPU"
    print(f"使用设备: {device}")
    model_path = train_yolov8(yaml_path, epochs=100)

    # Step 6: validate model performance.
    print("\n验证模型性能...")
    validate_model(model_path, yaml_path)
    print("\n====== 流程完成 ======")
10-12
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值