from ultralytics import YOLO
import os
import yaml
import torch
import matplotlib.pyplot as plt
from IPython.display import Image
class YOLOv8CustomTrainer:
def __init__(self, data_config_path, model_size='n'):
"""
初始化YOLOv8训练器
Args:
data_config_path (str): 数据集配置文件路径
model_size (str): 模型大小 ('n', 's', 'm', 'l', 'x')
"""
self.data_config_path = data_config_path
self.model_size = model_size
self.model = None
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {self.device}")
# 验证数据集配置
self._validate_data_config()
def _validate_data_config(self):
"""验证数据集配置文件"""
if not os.path.exists(self.data_config_path):
raise FileNotFoundError(f"数据集配置文件不存在: {self.data_config_path}")
with open(self.data_config_path, 'r') as f:
data_config = yaml.safe_load(f)
# 检查关键字段
required_keys = ['train', 'val', 'nc', 'names']
for key in required_keys:
if key not in data_config:
raise ValueError(f"数据集配置缺少必要字段: {key}")
# 检查路径是否存在
for path_key in ['train', 'val']:
path = data_config[path_key]
if not os.path.exists(path):
raise FileNotFoundError(f"数据集路径不存在: {path}")
print("✅ 数据集配置验证通过")
def load_model(self, pretrained=True):
"""
加载YOLOv8模型
Args:
pretrained (bool): 是否加载预训练权重
"""
model_name = f'yolov8{self.model_size}.pt'
print(f"加载模型: {model_name}")
if pretrained:
# 从官方预训练权重加载
self.model = YOLO(model_name)
else:
# 从头开始训练
self.model = YOLO(f'yolov8{self.model_size}.yaml').load(model_name)
return self.model
def train(self, epochs=100, imgsz=640, batch=16, **kwargs):
"""
训练YOLOv8模型
Args:
epochs (int): 训练轮数
imgsz (int): 输入图像尺寸
batch (int): 批次大小
**kwargs: 其他训练参数
Returns:
dict: 训练结果
"""
if self.model is None:
self.load_model()
# 优化训练参数(参考引用[1])
train_params = {
'data': self.data_config_path,
'epochs': epochs,
'imgsz': imgsz,
'batch': batch,
'device': self.device,
'optimizer': 'AdamW', # 推荐使用AdamW优化器
'lr0': 0.01, # 初始学习率
'lrf': 0.01, # 最终学习率 = lr0 * lrf
'momentum': 0.937,
'weight_decay': 0.0005,
'warmup_epochs': 3.0, # 预热轮数
'warmup_momentum': 0.8,
'warmup_bias_lr': 0.1,
'box': 7.5, # 框损失权重
'cls': 0.5, # 分类损失权重
'dfl': 1.5, # 分布焦点损失权重
'close_mosaic': 10, # 最后10轮关闭Mosaic增强
'amp': True, # 自动混合精度训练
'patience': 50, # 早停轮数
'save': True,
'save_period': 10, # 每10个epoch保存一次
'cache': True, # 使用RAM缓存加速训练
'single_cls': False, # 多类别训练
'cos_lr': True, # 余弦学习率调度
'overlap_mask': True,
'mask_ratio': 4,
'dropout': 0.0, # 分类器dropout概率
'name': 'custom_train',
'project': 'runs/detect',
**kwargs # 允许覆盖默认参数
}
print("开始训练模型...")
results = self.model.train(**train_params)
# 保存训练曲线
self._plot_training_curves(results)
print("✅ 训练完成!")
return results
def _plot_training_curves(self, results):
"""绘制训练曲线"""
try:
# 损失曲线
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(results.results['train/box_loss'], label='Box Loss')
plt.plot(results.results['train/cls_loss'], label='Cls Loss')
plt.plot(results.results['train/dfl_loss'], label='DFL Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# 验证指标
plt.subplot(1, 2, 2)
plt.plot(results.results['metrics/precision'], label='Precision')
plt.plot(results.results['metrics/recall'], label='Recall')
plt.plot(results.results['metrics/mAP50'], label='mAP50')
plt.plot(results.results['metrics/mAP50-95'], label='mAP50-95')
plt.title('Validation Metrics')
plt.xlabel('Epoch')
plt.ylabel('Value')
plt.legend()
plt.tight_layout()
plt.savefig('training_curves.png')
print("训练曲线已保存为 training_curves.png")
except Exception as e:
print(f"绘制训练曲线失败: {e}")
def validate(self, model_path=None, imgsz=640):
"""
验证模型性能
Args:
model_path (str): 模型路径
imgsz (int): 验证图像尺寸
Returns:
dict: 验证结果
"""
if model_path:
model = YOLO(model_path)
else:
model = self.model
if model is None:
raise ValueError("没有可用的模型进行验证")
print("开始模型验证...")
metrics = model.val(
data=self.data_config_path,
imgsz=imgsz,
batch=16,
conf=0.001, # 置信度阈值
iou=0.6, # IoU阈值
device=self.device,
split='val', # 使用验证集
plots=True, # 生成混淆矩阵等
save_json=True, # 保存JSON格式结果
save_hybrid=False,
half=False, # 使用FP32精度
rect=True, # 矩形验证
verbose=True
)
# 打印关键指标
print(f"验证结果:")
print(f" mAP@0.5: {metrics.box.map50:.4f}")
print(f" mAP@0.5:0.95: {metrics.box.map:.4f}")
print(f" 精确率: {metrics.box.precision:.4f}")
print(f" 召回率: {metrics.box.recall:.4f}")
# 显示混淆矩阵
try:
Image(filename=f'{metrics.save_dir}/confusion_matrix.png')
except:
pass
print("✅ 验证完成!")
return metrics
def predict(self, source, model_path=None, conf=0.25, imgsz=640):
"""
使用训练好的模型进行预测
Args:
source (str): 预测源(图片/视频/目录)
model_path (str): 模型路径
conf (float): 置信度阈值
imgsz (int): 图像尺寸
"""
if model_path:
model = YOLO(model_path)
else:
model = self.model
if model is None:
raise ValueError("没有可用的模型进行预测")
print(f"开始预测: {source}")
results = model.predict(
source=source,
conf=conf,
imgsz=imgsz,
save=True, # 保存带检测结果的图像
save_txt=False, # 保存检测结果文本
save_conf=True, # 保存置信度
save_crop=False, # 保存裁剪的检测结果
show_labels=True, # 显示标签
show_conf=True, # 显示置信度
show_boxes=True, # 显示边界框
line_width=2, # 边界框线宽
visualize=False, # 可视化模型特征
augment=False, # 测试时数据增强
agnostic_nms=False, # 类别无关NMS
retina_masks=False,
boxes=True,
device=self.device
)
print(f"✅ 预测完成! 结果保存在 {results[0].save_dir}")
return results
def export(self, model_path, format='onnx', imgsz=640):
"""
导出模型
Args:
model_path (str): 模型路径
format (str): 导出格式 ('onnx', 'torchscript', 'tflite', 'tfjs')
imgsz (int): 导出图像尺寸
"""
model = YOLO(model_path)
print(f"导出模型为 {format.upper()} 格式...")
export_params = {
'format': format,
'imgsz': imgsz,
'keras': False,
'optimize': True, # ONNX优化
'half': False, # FP32精度
'int8': False,
'dynamic': False,
'simplify': True, # ONNX简化
'opset': 12, # ONNX opset版本
'workspace': 4, # TensorRT工作空间大小(GB)
'nms': False,
'batch': 1
}
model.export(**export_params)
# 检查导出文件
export_file = model_path.replace('.pt', f'.{format}')
if os.path.exists(export_file):
print(f"✅ 模型已导出: {export_file}")
return export_file
else:
raise RuntimeError(f"模型导出失败: {export_file} 不存在")
def main():
# === 配置参数 ===
DATA_CONFIG_PATH = "path/to/your/data.yaml" # 数据集配置文件
MODEL_SIZE = 's' # 模型大小: n(nano), s(small), m(medium), l(large), x(xlarge)
EPOCHS = 100 # 训练轮数
IMGSZ = 640 # 输入图像尺寸
BATCH_SIZE = 16 # 批次大小
# 创建训练器
trainer = YOLOv8CustomTrainer(
data_config_path=DATA_CONFIG_PATH,
model_size=MODEL_SIZE
)
# === 训练模型 ===
trainer.load_model(pretrained=True) # 使用预训练权重
trainer.train(
epochs=EPOCHS,
imgsz=IMGSZ,
batch=BATCH_SIZE,
# 高级参数调整
lr0=0.01, # 初始学习率
weight_decay=0.05, # 权重衰减
dropout=0.2, # 防止过拟合
mosaic=1.0, # Mosaic数据增强概率
mixup=0.1, # MixUp数据增强概率
copy_paste=0.1, # Copy-Paste数据增强概率
name='custom_train_experiment'
)
# === 验证模型 ===
best_model_path = "runs/detect/custom_train_experiment/weights/best.pt"
trainer.validate(model_path=best_model_path, imgsz=IMGSZ)
# === 使用模型预测 ===
trainer.predict(
source="path/to/test/images",
model_path=best_model_path,
conf=0.25,
imgsz=IMGSZ
)
# === 导出模型 ===
trainer.export(
model_path=best_model_path,
format='onnx',
imgsz=IMGSZ
)
if __name__ == "__main__":
main()
这段代码我需要修改的地方,以及添加的路径进行红色标注