Introduction
In Task 2 of the Datawhale AI Summer Camp (Session 5), we dig further into the YOLO algorithm, learning how to apply YOLO in a real project: dataset preparation, model training, evaluation, and deployment. We also follow the score-boosting ideas proposed in the task documentation, modify the code, and try to improve model performance.
I. Introduction to the YOLO Model
1. The object detection task
2. The YOLO model
3. YOLO dataset format
4. YOLO training logs
II. Advanced Methods and Experiments
1. Increasing the amount of training data

Looking at the BaseLine code, only a small subset of the provided videos is used to build the dataset (the snippet below takes the first 5 for training and the last 3 for validation), while the competition organizers provide 53 files in total. We can therefore use all of the videos to build the dataset.
1. Data visualization
First, we visualize the training data:
import glob
import os.path
import sys
import cv2
from matplotlib import pyplot as plt

# Color table: one BGR color per category index
category_dic = {
    0: (0, 255, 0),
    1: (255, 0, 0),
    2: (0, 0, 255),
    3: (255, 0, 255)
}

def get_pts(label, img_height, img_width):
    '''
    Parse one YOLO-format label line and return the category index
    plus the two pixel-coordinate corners of its bounding box.
    '''
    label = label.split(" ")
    category_idx, x_center, y_center, width, height = label[0], label[1], label[2], label[3], label[4]
    x_center = float(x_center)
    y_center = float(y_center)
    width = float(width)
    height = float(height)
    # Convert the normalized center/size back to pixel corner coordinates
    x_max = (2 * x_center * img_width + img_width * width) / 2.0
    x_min = x_max - img_width * width
    y_max = (img_height * height + 2 * img_height * y_center) / 2.0
    y_min = y_max - img_height * height
    pt1 = (int(x_min), int(y_min))
    pt2 = (int(x_max), int(y_max))
    return category_idx, [pt1, pt2]

def draw_rectangle(pic_path, label_path, out):
    '''
    Draw the labeled bounding boxes on an image and save the result.
    '''
    if os.path.exists(out):
        return
    pic = cv2.imread(pic_path, cv2.IMREAD_COLOR)
    img_height, img_width = pic.shape[0], pic.shape[1]
    category_idxs = []
    pts = []
    # Open the label file and read its contents
    with open(label_path, 'r') as file:
        labels = file.readlines()
    for label in labels:
        cat_id, pt_list = get_pts(label, img_height, img_width)
        category_idxs.append(int(cat_id))
        pts.append(pt_list)
    for i in range(len(pts)):
        color = category_dic[category_idxs[i]]
        thickness = 2  # line width
        pt = pts[i]
        cv2.rectangle(pic, pt[0], pt[1], color, thickness)
    # frame = cv2.cvtColor(pic, cv2.COLOR_BGR2RGB)
    # imencode + tofile handles output paths that contain non-ASCII characters
    cv2.imencode('.png', pic)[1].tofile(out)

if __name__ == '__main__':
    out_path = r"your path"
    imgs = glob.glob('your path')
    for img in imgs:
        # Label path: same name as the image, with a .txt extension
        label = img.replace(".jpg", ".txt")
        # Output file name
        out_ = os.path.join(out_path, os.path.basename(label).replace(".txt", ".png"))
        draw_rectangle(img, label, out_)
We find that the dataset contains a large number of highly similar, nearly identical frames:

We therefore subsample frames when generating the dataset. Since the source data are videos recorded at 25 frames per second, we keep one frame out of every 30, i.e. roughly one frame every 1.2 seconds.
2. Frame sampling
BaseLine code:
# Training set: the first 5 videos
for anno_path, video_path in zip(train_annos[:5], train_videos[:5]):
    print(video_path)
    anno_df = pd.read_json(anno_path)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
        cv2.imwrite('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)
        if len(frame_anno) != 0:
            with open('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:
                for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                    category_idx = category_labels.index(category)
                    x_min, y_min, x_max, y_max = bbox
                    # Convert pixel corners to normalized YOLO format: class x_center y_center width height
                    x_center = (x_min + x_max) / 2 / img_width
                    y_center = (y_min + y_max) / 2 / img_height
                    width = (x_max - x_min) / img_width
                    height = (y_max - y_min) / img_height
                    if x_center > 1:
                        print(bbox)
                    up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')
        frame_idx += 1

# Validation set: the last 3 videos
for anno_path, video_path in zip(train_annos[-3:], train_videos[-3:]):
    print(video_path)
    anno_df = pd.read_json(anno_path)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
        cv2.imwrite('./yolo-dataset/val/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)
        if len(frame_anno) != 0:
            with open('./yolo-dataset/val/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:
                for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                    category_idx = category_labels.index(category)
                    x_min, y_min, x_max, y_max = bbox
                    x_center = (x_min + x_max) / 2 / img_width
                    y_center = (y_min + y_max) / 2 / img_height
                    width = (x_max - x_min) / img_width
                    height = (y_max - y_min) / img_height
                    up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')
        frame_idx += 1
Following this plan, we keep one frame out of every 30. Modified code:
# Training set: all but the last 10 videos
for anno_path, video_path in zip(train_annos[:-10], train_videos[:-10]):
    print(video_path)
    anno_df = pd.read_json(anno_path)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
        # Keep one frame out of every 30 (frame subsampling)
        if frame_idx % 30 == 0:
            cv2.imwrite('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)
            if len(frame_anno) != 0:
                with open('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt',
                          'w') as up:
                    for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                        category_idx = category_labels.index(category)
                        x_min, y_min, x_max, y_max = bbox
                        x_center = (x_min + x_max) / 2 / img_width
                        y_center = (y_min + y_max) / 2 / img_height
                        width = (x_max - x_min) / img_width
                        height = (y_max - y_min) / img_height
                        if x_center > 1:
                            print(bbox)
                        up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')
        frame_idx += 1
# Validation set: the last 10 videos
for anno_path, video_path in zip(train_annos[-10:], train_videos[-10:]):
    print(video_path)
    anno_df = pd.read_json(anno_path)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
        # Keep one frame out of every 20 (frame subsampling)
        if frame_idx % 20 == 0:
            # Split on '\\' here because these annotation paths are Windows-style
            cv2.imwrite('./yolo-dataset/val/' + anno_path.split('\\')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)
            if len(frame_anno) != 0:
                with open('./yolo-dataset/val/' + anno_path.split('\\')[-1][:-5] + '_' + str(frame_idx) + '.txt',
                          'w') as up:
                    for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                        category_idx = category_labels.index(category)
                        x_min, y_min, x_max, y_max = bbox
                        x_center = (x_min + x_max) / 2 / img_width
                        y_center = (y_min + y_max) / 2 / img_height
                        width = (x_max - x_min) / img_width
                        height = (y_max - y_min) / img_height
                        up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')
        frame_idx += 1
The final training set contains 2,060 images and the validation set contains 509 images.
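As a quick sanity check, the generated image and label files can be counted directly. This is a minimal sketch that only assumes the ./yolo-dataset directory layout used above:
import glob

# Count the generated images and label files in the train and val directories
train_imgs = glob.glob('./yolo-dataset/train/*.jpg')
train_labels = glob.glob('./yolo-dataset/train/*.txt')
val_imgs = glob.glob('./yolo-dataset/val/*.jpg')
val_labels = glob.glob('./yolo-dataset/val/*.txt')
print(f'train: {len(train_imgs)} images, {len(train_labels)} labels')
print(f'val:   {len(val_imgs)} images, {len(val_labels)} labels')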
2. Switching to different pretrained weights
1. Setting up a local environment
Because the online platform offers limited compute, we try setting up the environment on a local server.
We first tried installing ultralytics via conda, but it was very slow, so we gave up on that approach.

We then downloaded the source code from GitHub, placed it in the project directory, and tried running the code; it worked.

Yay, it runs!
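One quick way to confirm the local setup works is to load a pretrained checkpoint and run a single prediction. This is a minimal sketch; the yolov8n.pt checkpoint and the sample image path are placeholders rather than values from this project:
from ultralytics import YOLO

# Load a small pretrained checkpoint; ultralytics downloads it if it is not present locally
model = YOLO('yolov8n.pt')

# Run a single prediction on one extracted frame (placeholder path)
results = model.predict('./yolo-dataset/val/sample_frame.jpg', imgsz=640)
print(results[0].boxes)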

2. Testing the model
Training in progress…
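For the weight-switching experiment, a minimal training sketch could look like the following. The yolov8s.pt checkpoint, the yolo.yaml dataset config path, and the hyperparameters are assumptions based on the baseline setup, not values reported in this post:
from ultralytics import YOLO

# Swap the baseline checkpoint for a larger pretrained one (assumed choice)
model = YOLO('yolov8s.pt')

# Train on the dataset built above; the data config path and hyperparameters are placeholders
model.train(
    data='./yolo-dataset/yolo.yaml',  # lists train/val image dirs and class names
    epochs=10,
    imgsz=1080,
    batch=16,
)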
