import json
import os
import random
import shutil
from pathlib import Path

import cv2
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from ultralytics import YOLO
# ================== Configuration ==================
VIDEO_PATH = r"E:\专业综合实践-光电镊\参考资料\参考资料\41378_2025_892_MOESM3_ESM.mp4"  # input video file
OUTPUT_DIR = r"E:\专业综合实践-光电镊\参考资料\参考资料\41378_2025_892_MOESM3_ESM.dataset"  # dataset output root
BBOX_SIZE = 5        # side length of the bounding box drawn around each black dot
FRAME_INTERVAL = 10  # keep one frame out of every FRAME_INTERVAL
TRAIN_SPLIT = 0.8    # fraction of images that go to the training set

# Create the dataset directory tree (idempotent).
for _subdir in ("images", "labels",
                "train/images", "train/labels",
                "val/images", "val/labels"):
    os.makedirs(os.path.join(OUTPUT_DIR, _subdir), exist_ok=True)
# ================== 1. 视频抽帧 ==================
def extract_frames(video_path, output_dir, frame_interval=10):
    """Extract every `frame_interval`-th frame of a video as a JPEG.

    Frames are written to ``<output_dir>/images/frame_XXXXXX.jpg``.

    Args:
        video_path: Path of the input video file.
        output_dir: Dataset root directory (must contain an ``images`` subdir).
        frame_interval: Sampling stride; keep one frame per this many frames.

    Returns:
        List of file paths of the saved frames, in temporal order.

    Raises:
        ValueError: If the video cannot be opened.
        IOError: If a frame cannot be written to disk.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"无法打开视频文件: {video_path}")
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f"视频信息: {frame_count}帧, {fps:.1f}FPS, {width}x{height}分辨率")
    frame_paths = []
    frame_index = 0
    try:
        # NOTE: CAP_PROP_FRAME_COUNT is only an estimate for some codecs,
        # so read until the stream is exhausted instead of trusting it.
        with tqdm(total=frame_count, desc="提取帧") as pbar:
            i = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if i % frame_interval == 0:
                    frame_path = os.path.join(output_dir, "images", f"frame_{frame_index:06d}.jpg")
                    # imwrite fails silently (returns False) on bad paths /
                    # full disks — surface that as an error.
                    if not cv2.imwrite(frame_path, frame):
                        raise IOError(f"无法写入图像: {frame_path}")
                    frame_paths.append(frame_path)
                    frame_index += 1
                i += 1
                pbar.update(1)
    finally:
        # Always release the capture handle, even on error.
        cap.release()
    print(f"成功提取 {frame_index} 帧图像")
    return frame_paths
# ================== 2. 标注工具 ==================
def annotate_frames(frame_paths):
    """Interactive point-annotation tool (OpenCV window).

    Shows each frame; a left mouse click marks one dot. Keyboard controls:
    Z undo last point, C clear current frame, B previous frame,
    N or S next frame, Q quit.

    Args:
        frame_paths: List of image file paths to annotate, in display order.

    Returns:
        Dict mapping image path -> list of ``(x, y)`` click coordinates
        (pixel coordinates in the displayed image).
    """
    annotations = {}
    window_name = "标注工具 [Z:撤销 | C:清除 | B:上一帧 | S:保存 | Q:退出]"
    # State shared with the mouse callback via closure.
    current_index = 0
    total_frames = len(frame_paths)
    redraw_needed = True  # set whenever the displayed image must be refreshed
    # Mouse callback: records a left click as an annotation point on the
    # currently displayed frame.
    def click_event(event, x, y, flags, param):
        nonlocal current_index, redraw_needed
        path = frame_paths[current_index]
        if path not in annotations:
            annotations[path] = []
        points = annotations[path]
        if event == cv2.EVENT_LBUTTONDOWN:
            points.append((x, y))
            redraw_needed = True  # tell the main loop to redraw
    # One window instance reused for every frame.
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.setMouseCallback(window_name, click_event)
    while 0 <= current_index < total_frames:
        path = frame_paths[current_index]
        if path not in annotations:
            annotations[path] = []
        points = annotations[path]
        # Refresh the image only when something changed.
        if redraw_needed:
            img = cv2.imread(path)
            if img is None:
                # Unreadable image: warn and skip to the next frame.
                print(f"警告: 无法读取图像 {path}")
                current_index += 1
                continue
            display_img = img.copy()
            # Draw every annotated point as a small filled red circle.
            for point in points:
                cv2.circle(display_img, point, 3, (0, 0, 255), -1)
            # Overlay frame index and point count.
            status_text = f"Frame {current_index + 1}/{total_frames} - 标注点: {len(points)}"
            cv2.putText(display_img, status_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow(window_name, display_img)
            redraw_needed = False  # reset until the next change
        key = cv2.waitKey(10)  # short wait keeps the UI responsive
        if key != -1:  # a key was pressed
            key = key & 0xFF
            if key == ord('z') and points:  # undo last point
                points.pop()
                redraw_needed = True
            elif key == ord('c'):  # clear all points on this frame
                annotations[path] = []
                redraw_needed = True
            elif key == ord('b'):  # go to previous frame
                current_index = max(0, current_index - 1)
                redraw_needed = True
            elif key in [ord('n'), ord('s')]:  # next frame / save-and-advance
                current_index = min(total_frames - 1, current_index + 1)
                redraw_needed = True
            elif key == ord('q'):  # quit annotation
                break
    cv2.destroyAllWindows()
    return annotations
# ================== 3. 转换为YOLO格式 ==================
def save_yolo_annotations(annotations, output_dir, bbox_size):
    """Convert click annotations to YOLO-format label files.

    Each annotated point becomes a square box of side ``bbox_size``
    (clipped at the image border), written as
    ``class cx cy w h`` in coordinates normalized to [0, 1].
    Class id 0 = black dot.

    Args:
        annotations: Dict of image path -> list of (x, y) points.
        output_dir: Dataset root; labels go into ``<output_dir>/labels``.
        bbox_size: Desired side length of each box in pixels.
    """
    for img_path, points in annotations.items():
        # Read the image only to obtain its dimensions for normalization.
        img = cv2.imread(img_path)
        if img is None:
            # Skip unreadable images instead of crashing on img.shape.
            print(f"警告: 无法读取图像 {img_path}")
            continue
        height, width = img.shape[:2]
        label_name = os.path.splitext(os.path.basename(img_path))[0] + ".txt"
        label_path = os.path.join(output_dir, "labels", label_name)
        # Use the true half-size: the original `bbox_size // 2` on both
        # sides produced boxes of side bbox_size - 1 for odd sizes.
        half = bbox_size / 2
        with open(label_path, 'w') as f:
            for (x, y) in points:
                # Clip the box to the image bounds.
                x0 = max(0.0, x - half)
                y0 = max(0.0, y - half)
                x1 = min(float(width), x + half)
                y1 = min(float(height), y + half)
                # YOLO format: class_id center_x center_y width height,
                # all normalized by the image size.
                center_x = (x0 + x1) / 2 / width
                center_y = (y0 + y1) / 2 / height
                box_width = (x1 - x0) / width
                box_height = (y1 - y0) / height
                f.write(f"0 {center_x:.6f} {center_y:.6f} {box_width:.6f} {box_height:.6f}\n")
# ================== 4. 划分数据集 ==================
def split_dataset(output_dir, train_split=0.8):
    """Split the dataset into train/val subsets and write ``data.yaml``.

    Args:
        output_dir: Dataset root containing ``images`` and ``labels`` dirs
            plus pre-created ``train``/``val`` subtrees.
        train_split: Fraction of images assigned to the training set.

    Returns:
        Path of the generated YOLO dataset configuration file.
    """
    images_dir = os.path.join(output_dir, "images")
    # Sort: os.listdir order is platform-arbitrary, so without sorting the
    # fixed random_state would NOT give a reproducible split.
    all_images = sorted(f for f in os.listdir(images_dir) if f.endswith('.jpg'))
    train_images, val_images = train_test_split(
        all_images, train_size=train_split, random_state=42)
    print(f"数据集划分: {len(train_images)}训练, {len(val_images)}验证")

    def copy_files(file_list, subset):
        # Copy each image and, if present, its label into the subset dirs.
        for img_file in file_list:
            shutil.copy(os.path.join(images_dir, img_file),
                        os.path.join(output_dir, subset, "images", img_file))
            label_file = os.path.splitext(img_file)[0] + ".txt"
            src_label = os.path.join(output_dir, "labels", label_file)
            if os.path.exists(src_label):
                # Images without a label file stay as negative samples.
                shutil.copy(src_label,
                            os.path.join(output_dir, subset, "labels", label_file))

    copy_files(train_images, "train")
    copy_files(val_images, "val")
    # YOLO dataset config; keys must start at column 0 to be valid YAML.
    data_yaml = f"""\
path: {os.path.abspath(output_dir)}
train: train/images
val: val/images
# 类别数量
nc: 1
# 类别名称
names: ['black_dot']
"""
    yaml_path = os.path.join(output_dir, "data.yaml")
    with open(yaml_path, 'w') as f:
        f.write(data_yaml)
    return yaml_path
# ================== 5. 训练YOLOv8模型 ==================
def train_yolov8(yaml_path, model_name="yolov8n.pt", epochs=100):
    """Train a YOLOv8 detector on the generated dataset.

    Args:
        yaml_path: Path to the YOLO ``data.yaml`` configuration file.
        model_name: Pretrained checkpoint to start from.
        epochs: Maximum number of training epochs.

    Returns:
        Path to the best checkpoint (``best.pt``) produced by training.
    """
    # Local import: the module-level code defers `import torch` to the
    # __main__ block, so this function would otherwise raise NameError
    # when called from elsewhere.
    import torch
    model = YOLO(model_name)
    train_args = {
        'data': yaml_path,
        'epochs': epochs,
        'imgsz': 640,
        'batch': 16,
        'name': 'black_dot_detector',
        'patience': 20,  # early stopping: epochs without improvement
        'device': 0 if torch.cuda.is_available() else 'cpu',
    }
    results = model.train(**train_args)
    # ultralytics stores run artifacts under results.save_dir.
    best_model_path = results.save_dir / 'weights' / 'best.pt'
    print(f"训练完成! 最佳模型保存于: {best_model_path}")
    return best_model_path
# ================== 6. 模型验证 ==================
def validate_model(model_path, data_yaml):
    """Validate a trained model and print its detection metrics.

    Args:
        model_path: Path of the trained model weights (``best.pt``).
        data_yaml: Path of the YOLO dataset configuration file.

    Returns:
        The ultralytics validation metrics object.
    """
    # Local import: torch may not yet be imported at module level.
    import torch
    model = YOLO(model_path)
    metrics = model.val(
        data=data_yaml,
        batch=16,
        imgsz=640,
        conf=0.25,
        iou=0.45,
        device=0 if torch.cuda.is_available() else 'cpu'
    )
    print("验证结果:")
    print(f"mAP@0.5: {metrics.box.map50:.4f}")
    print(f"mAP@0.5:0.95: {metrics.box.map:.4f}")
    # ultralytics DetMetrics exposes mean precision/recall as `mp`/`mr`;
    # the original `.precision`/`.recall` attributes do not exist and
    # raised AttributeError.
    print(f"精确率: {metrics.box.mp:.4f}")
    print(f"召回率: {metrics.box.mr:.4f}")
    return metrics
# ================== 主执行流程 ==================
def _main():
    """Run the full pipeline: extract → annotate → convert → split → train → validate."""
    # Step 1: extract frames from the source video.
    frame_paths = extract_frames(VIDEO_PATH, OUTPUT_DIR, FRAME_INTERVAL)

    # Step 2: interactively annotate the extracted frames.
    print("\n开始标注...")
    annotations = annotate_frames(frame_paths)
    print(f"标注完成! 共标注 {len(annotations)} 帧图像")

    # Step 3: write the annotations out in YOLO label format.
    save_yolo_annotations(annotations, OUTPUT_DIR, BBOX_SIZE)
    print("YOLO格式标注已保存")

    # Step 4: split into train/val sets and emit data.yaml.
    yaml_path = split_dataset(OUTPUT_DIR, TRAIN_SPLIT)
    print(f"数据集配置文件已创建: {yaml_path}")

    # Step 5: train the YOLOv8 model.
    print("\n开始训练YOLOv8模型...")
    import torch  # deferred so CUDA availability is checked lazily
    device = "GPU" if torch.cuda.is_available() else "CPU"
    print(f"使用设备: {device}")
    model_path = train_yolov8(yaml_path, epochs=100)

    # Step 6: evaluate the trained model.
    print("\n验证模型性能...")
    validate_model(model_path, yaml_path)
    print("\n====== 流程完成 ======")


if __name__ == "__main__":
    _main()
# (removed stray non-code text that caused a SyntaxError)