import cv2
import numpy as np
import os
import random
from pathlib import Path
from tqdm import tqdm
import yaml
from ultralytics import YOLO
import shutil # 添加用于文件操作的库
class DotDatasetGenerator:
    """Generate a synthetic YOLO-format detection dataset of small black dots.

    Each image is a flat gray square with white rings and black polygon
    outlines ("obstacles") drawn as distractors, plus a random number of small
    filled black circles. Only the circles are labeled (class 0). Images are
    written to ``<output_dir>/images/{train,val}`` and YOLO label files to
    ``<output_dir>/labels/{train,val}``.
    """

    def __init__(self, output_dir="black_dot_dataset", image_size=640):
        """Create the output directory tree and set the generation parameters.

        :param output_dir: root directory of the generated dataset
        :param image_size: side length, in pixels, of the square images
        """
        self.output_dir = output_dir
        self.image_size = image_size
        # Directory layout: images/<split> and labels/<split>
        self.train_img_dir = os.path.join(output_dir, "images", "train")
        self.val_img_dir = os.path.join(output_dir, "images", "val")
        self.train_label_dir = os.path.join(output_dir, "labels", "train")
        self.val_label_dir = os.path.join(output_dir, "labels", "val")
        for directory in (self.train_img_dir, self.val_img_dir,
                          self.train_label_dir, self.val_label_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)
        # 3-channel drawing colors
        self.gray_bg = (128, 128, 128)
        self.black = (0, 0, 0)
        self.white = (255, 255, 255)
        # Inclusive ranges for how many objects of each kind are drawn per image
        self.min_dots = 10
        self.max_dots = 30
        self.min_rings = 1
        self.max_rings = 3
        self.min_obstacles = 2
        self.max_obstacles = 5

    def _create_base_image(self):
        """Return a new image_size x image_size x 3 uint8 image filled with the background gray."""
        return np.full((self.image_size, self.image_size, 3), self.gray_bg, dtype=np.uint8)

    def _add_blur(self, img, intensity=3):
        """Gaussian-blur img with a (2*intensity+1)^2 kernel and sigma=intensity."""
        return cv2.GaussianBlur(img, (intensity*2+1, intensity*2+1), intensity)

    def _add_random_ring(self, img):
        """Draw a white ring that lies fully inside the image.

        :return: (center_x, center_y, radius) of the ring's center line
        """
        radius = random.randint(80, 150)
        thickness = random.randint(8, 15)
        center_x = random.randint(radius + thickness, self.image_size - radius - thickness)
        center_y = random.randint(radius + thickness, self.image_size - radius - thickness)
        cv2.circle(img, (center_x, center_y), radius, self.white, thickness)
        return center_x, center_y, radius

    def _add_random_obstacle(self, img):
        """Draw a closed black outline through 3-5 random vertices inside a 100x100 box.

        The vertex order is random, so the outline may self-intersect.

        :return: the vertices' bounding rectangle as (x, y, w, h)
        """
        sides = random.choice([3, 4, 5])
        points = []
        start_x = random.randint(50, self.image_size - 150)
        start_y = random.randint(50, self.image_size - 150)
        for _ in range(sides):
            points.append([start_x + random.randint(0, 100), start_y + random.randint(0, 100)])
        pts = np.array(points, np.int32).reshape((-1, 1, 2))
        thickness = random.randint(3, 8)
        cv2.polylines(img, [pts], True, self.black, thickness)
        return cv2.boundingRect(pts)

    def _is_near_ring(self, x, y, ring_x, ring_y, ring_radius):
        """Return True if (x, y) is within +/-10% of ring_radius from the ring center."""
        distance = np.sqrt((x - ring_x)**2 + (y - ring_y)**2)
        return ring_radius * 0.9 <= distance <= ring_radius * 1.1

    def _is_in_obstacle(self, x, y, obstacles):
        """Return True if (x, y) lies inside any (x, y, w, h) rectangle grown by 5 px."""
        for obstacle in obstacles:
            x1, y1, w, h = obstacle
            if (x1 - 5 <= x <= x1 + w + 5) and (y1 - 5 <= y <= y1 + h + 5):
                return True
        return False

    @staticmethod
    def _is_too_close(x, y, dots, min_distance):
        """Return True if (x, y) is closer than min_distance to any (dx, dy, r) dot center."""
        return any(np.hypot(x - dx, y - dy) < min_distance for dx, dy, _ in dots)

    def _overlaps_dark_pixels(self, img, x, y, radius):
        """Return True if the dot's bounding square already contains a pure-black pixel.

        A black dot drawn over a black obstacle stroke (or an earlier dot) is
        invisible but would still be labeled, which injects label noise.
        """
        height, width = img.shape[:2]
        patch = img[max(y - radius, 0):min(y + radius + 1, height),
                    max(x - radius, 0):min(x + radius + 1, width)]
        return bool(np.any(np.all(patch == self.black, axis=-1)))

    def _add_black_dots(self, img, rings, obstacles):
        """Draw filled black dots on img and return them as (x, y, radius) tuples.

        All dots in one image share a single radius (2-5 px). A few dots are
        placed on ring center lines, some of the last ~20% on an obstacle's
        bounding-box edge, the rest uniformly at random. A candidate is dropped
        (so fewer than the sampled count may be drawn) if it is closer than
        3 * radius to an earlier dot or would cover existing black pixels.
        """
        num_dots = random.randint(self.min_dots, self.max_dots)
        dot_radius = random.randint(2, 5)
        dots = []
        ring_dots = min(int(num_dots * 0.3), len(rings))
        for i in range(num_dots):
            # Default: uniformly random position with a margin from the border
            x = random.randint(dot_radius + 5, self.image_size - dot_radius - 5)
            y = random.randint(dot_radius + 5, self.image_size - dot_radius - 5)
            # Some dots sit on a ring's center line
            if i < ring_dots and rings:
                ring_x, ring_y, ring_radius = random.choice(rings)
                angle = random.uniform(0, 2 * np.pi)
                x = int(ring_x + ring_radius * np.cos(angle))
                y = int(ring_y + ring_radius * np.sin(angle))
            # Some dots sit on an obstacle's bounding-box edge
            elif i > num_dots * 0.8 and random.random() < 0.5 and obstacles:
                ox, oy, ow, oh = random.choice(obstacles)
                side = random.choice(['top', 'bottom', 'left', 'right'])
                if side == 'top':
                    x, y = random.randint(ox, ox + ow), oy
                elif side == 'bottom':
                    x, y = random.randint(ox, ox + ow), oy + oh
                elif side == 'left':
                    x, y = ox, random.randint(oy, oy + oh)
                else:
                    x, y = ox + ow, random.randint(oy, oy + oh)
            # Validate the FINAL position: checking the random candidate instead
            # let ring/obstacle dots overlap earlier dots.
            if self._is_too_close(x, y, dots, dot_radius * 3):
                continue
            # Skip dots that would be hidden on a black stroke (invisible but labeled)
            if self._overlaps_dark_pixels(img, x, y, dot_radius):
                continue
            cv2.circle(img, (x, y), dot_radius, self.black, -1)
            dots.append((x, y, dot_radius))
        return dots

    def generate_image(self, save_path, label_path):
        """Render one image, save it, and write its YOLO label file.

        Each label line is ``0 cx cy w h`` normalized by image_size, with the
        box side equal to the dot diameter.

        :return: the (blurred) image array that was saved
        """
        img = self._create_base_image()
        rings = [self._add_random_ring(img) for _ in range(random.randint(self.min_rings, self.max_rings))]
        obstacles = [self._add_random_obstacle(img) for _ in range(random.randint(self.min_obstacles, self.max_obstacles))]
        dots = self._add_black_dots(img, rings, obstacles)
        img = self._add_blur(img, intensity=random.randint(2, 4))
        cv2.imwrite(save_path, img)
        with open(label_path, 'w') as f:
            for x, y, r in dots:
                f.write(f"0 {x/self.image_size:.6f} {y/self.image_size:.6f} {(r*2)/self.image_size:.6f} {(r*2)/self.image_size:.6f}\n")
        return img

    def generate_dataset(self, num_train=500, num_val=100):
        """Generate the train and val splits, then write the dataset YAML.

        :return: path of the written YAML config file
        """
        print(f"生成训练集 ({num_train}张图片)...")
        for i in tqdm(range(num_train)):
            self.generate_image(os.path.join(self.train_img_dir, f"train_{i}.jpg"),
                                os.path.join(self.train_label_dir, f"train_{i}.txt"))
        print(f"生成验证集 ({num_val}张图片)...")
        for i in tqdm(range(num_val)):
            self.generate_image(os.path.join(self.val_img_dir, f"val_{i}.jpg"),
                                os.path.join(self.val_label_dir, f"val_{i}.txt"))
        return self.create_yaml_config()

    def create_yaml_config(self):
        """Write ``black_dot.yaml`` (absolute root, split dirs, class names) and return its path."""
        config = {
            'path': os.path.abspath(self.output_dir),
            'train': 'images/train',
            'val': 'images/val',
            'names': {0: 'black_dot'}
        }
        config_path = os.path.join(self.output_dir, "black_dot.yaml")
        with open(config_path, 'w') as f:
            yaml.dump(config, f)
        print(f"数据集配置文件已保存至: {config_path}")
        return config_path
def train_yolov8_model(dataset_config, model_size='s', epochs=500, imgsz=640, batch=4):
    """Train a YOLOv8 detector for small black dots, keeping only best.pt and last.pt.

    After training, ``<run>/weights/best.pt`` and ``last.pt`` are copied to
    ``<run>/final_weights`` and the original ``weights`` directory is deleted
    (best effort: a failed delete is only reported).

    :param dataset_config: path to the dataset YAML file
    :param model_size: YOLOv8 model size (n/s/m/l/x)
    :param epochs: number of training epochs
    :param imgsz: training image size in pixels
    :param batch: batch size; keep it small on low-memory GPUs (e.g. 4 GB)
    :return: (path to best.pt, path to last.pt) inside ``final_weights``
    :raises FileNotFoundError: if training produced no best.pt or last.pt
    """
    # Imported here so the function does not depend on a module-level `torch`
    # that only exists when this file runs as a script.
    import torch

    # Load the pretrained checkpoint for the chosen size
    model = YOLO(f'yolov8{model_size}.pt')
    model.train(
        data=dataset_config,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        name='black_dot_detection',
        workers=2,  # fewer dataloader workers
        close_mosaic=15,  # turn mosaic off for the final 15 epochs
        # Ultralytics documents `augment` as "apply image augmentation to
        # prediction sources" (test-time augmentation); training augmentation
        # is controlled by the hsv_*/translate/scale/flip keys.
        augment=True,
        fliplr=0.0,
        overlap_mask=False,  # segmentation-only setting; no effect on detection
        # Learning rate and optimizer. Ultralytics guidance: SGD=1E-2, Adam=1E-3;
        # 0.01 is too high for AdamW and tends to destabilize training.
        lr0=0.001,
        lrf=0.1,
        optimizer='AdamW',
        weight_decay=0.0005,
        # Color / geometric augmentation
        hsv_h=0.015,
        hsv_s=0.5,
        hsv_v=0.3,
        translate=0.1,
        # Scale gain is drawn from [1 - scale, 1 + scale]; 0.9 could shrink the
        # 4-10 px dots below one pixel, so use the Ultralytics default 0.5.
        scale=0.5,
        # Multi-scale training varies the input size per batch; larger sizes raise
        # peak GPU memory, so disable this first if a 4 GB GPU runs out of memory.
        multi_scale=True,
        plots=True,
        save_period=0,  # values < 1 disable periodic epoch checkpoints
        amp=True,  # mixed precision lowers GPU memory use
        patience=500,  # with epochs=500 early stopping effectively never triggers
    )
    # Read the run directory from the trainer instead of the metrics object that
    # model.train() returns (it can be None in some setups).
    save_dir = str(model.trainer.save_dir)
    weights_dir = os.path.join(save_dir, 'weights')
    best_pt = os.path.join(weights_dir, 'best.pt')
    last_pt = os.path.join(weights_dir, 'last.pt')

    # Copy only best.pt and last.pt into a dedicated final directory
    final_save_dir = os.path.join(save_dir, 'final_weights')
    os.makedirs(final_save_dir, exist_ok=True)
    final_best_pt = os.path.join(final_save_dir, 'best.pt')
    final_last_pt = os.path.join(final_save_dir, 'last.pt')
    shutil.copy2(best_pt, final_best_pt)
    shutil.copy2(last_pt, final_last_pt)

    # Best-effort removal of the original weights directory
    try:
        shutil.rmtree(weights_dir)
        print(f"已删除中间权重目录: {weights_dir}")
    except OSError as e:
        print(f"删除中间权重目录失败: {str(e)}")
    return final_best_pt, final_last_pt
# Main pipeline: generate the dataset, train the model, then evaluate best.pt.
if __name__ == "__main__":
    import torch  # PyTorch is required; training picks CUDA when torch.cuda.is_available()
    print("="*50)
    print("生成数据集...")
    print("="*50)
    # Generate the synthetic dataset
    generator = DotDatasetGenerator(output_dir="black_dot_dataset")
    dataset_config = generator.generate_dataset(num_train=500, num_val=100)
    print("\n" + "="*50)
    print("开始训练YOLOv8模型...")
    print("="*50)
    # Train. Sizes: 'n' (nano), 's' (small), 'm' (medium), 'l' (large), 'x' (extra large)
    model_size = 's'
    best_pt_path, last_pt_path = train_yolov8_model(dataset_config, model_size)
    print("\n" + "="*50)
    print("训练完成!")
    print(f"最佳模型已保存至: {best_pt_path}")
    print(f"最终检查点已保存至: {last_pt_path}")
    print("="*50)
    # Evaluate the best checkpoint on the validation split
    print("\n在验证集上测试模型性能...")
    model = YOLO(best_pt_path)
    metrics = model.val(data=dataset_config)
    print("\n" + "="*50)
    print("模型评估结果:")
    # Ultralytics box metrics expose mean precision/recall as `mp`/`mr` (there is
    # no `recall` attribute), and mAP@0.5 (`map50`) is not precision.
    print(f"精确率(Precision): {metrics.box.mp:.4f}")
    print(f"召回率: {metrics.box.mr:.4f}")
    print(f"mAP@0.5: {metrics.box.map50:.4f}")
    print(f"mAP@0.5:0.95: {metrics.box.map:.4f}")
    print("="*50)
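# TODO(review): 想进一步提高精度,可以怎么改,显存只有4GB — i.e. how can detection accuracy be improved further with only 4 GB of GPU memory?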