如何使用深度学习框架(YOLOv8)对RTTS雾数据集和RESIDE数据集进行训练及推理和可视化
RTTS雾数据集(yolo格式的RTTS的txt标签) RESIDE数据集(RTTS、OTS、ITS、SOTS、HSTS、Unannotated Real-worls Hazy)
1
1
1
使用深度学习框架(如YOLOv8)对RTTS雾数据集和RESIDE数据集中的其他子集进行训练、推理及可视化,我们需要遵循以下步骤:准备数据集、配置模型参数、训练模型、评估模型性能以及推理和结果可视化。以下是详细的指南和代码示例。
1. 数据集准备
假设你的数据集结构如下:
RESIDE/
├── RTTS/
│ ├── images/
│ │ ├── train/
│ │ └── val/
│ └── labels/
│ ├── train/
│ └── val/
├── OTS/
│ ├── images/
│ │ ├── train/
│ │ └── val/
│ └── labels/
│ ├── train/
│ └── val/
└── ITS/
├── images/
│ ├── train/
│ └── val/
└── labels/
├── train/
└── val/
data_RESIDE.yaml
data_RESIDE.yaml
文件内容示例:
train: ./RESIDE/RTTS/images/train/
val: ./RESIDE/RTTS/images/val/
nc: 1 # 假设只有一类目标
names: ['target_class']
请根据实际情况调整类别数量和名称。
2. 安装依赖库
确保安装了必要的库:
pip install ultralytics opencv-python-headless tensorboard
3. 模型训练
创建一个Python脚本来开始训练过程。这里我们以YOLOv8为例说明如何训练模型。
训练脚本
from ultralytics import YOLO
def main_train():
# 加载预训练的YOLOv8n模型或从头开始定义模型
model = YOLO('yolov8n.yaml') # 或者直接加载预训练权重,如 'yolov8n.pt'
results = model.train(
data='./data_RESIDE.yaml',
epochs=100, # 根据需要调整
imgsz=640,
batch=16,
project='./runs/detect',
name='RESIDE_detection',
optimizer='SGD',
device='0', # 使用GPU编号
save=True,
cache=True,
verbose=True,
)
if __name__ == '__main__':
main_train()
4. 推理与结果可视化
训练完成后,我们可以利用训练好的模型对新图片进行预测,并将结果可视化。
推理脚本
import cv2
from PIL import Image
from ultralytics import YOLO
model = YOLO('./runs/detect/RESIDE_detection/weights/best.pt')
def detect_objects(image_path):
results = model.predict(source=image_path)
img = cv2.imread(image_path)
for result in results:
boxes = result.boxes.numpy()
for box in boxes:
r = box.xyxy
x1, y1, x2, y2 = int(r[0]), int(r[1]), int(r[2]), int(r[3])
label_id = int(box.cls)
label = result.names[label_id]
confidence = box.conf
if confidence > 0.5: # 设置置信度阈值
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2) # 绘制矩形框
cv2.putText(img, f'{label} {confidence:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
return img
# 示例调用
result_image = detect_objects('your_test_image.jpg') # 确保测试图像路径正确
Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)).show() # 使用PIL显示图像
5. 可视化界面
为了监控训练过程,可以使用TensorBoard。在训练脚本中添加 tensorboard=True
参数,然后运行以下命令启动TensorBoard:
tensorboard --logdir runs/
然后在浏览器中访问 http://localhost:6006
查看训练进度和结果。