目录
文件下载:https://download.youkuaiyun.com/download/qq_37116150/12289213
1. 预处理标注文件
首先将全局变量定义完成:
IMGSZ = 512 # 输入图片尺寸大小,必须是512x512
GRIDSZ = 16 # 网络最后输出的尺寸, 16x16
num_classes = 3 # 标签类别,共有3类,其中一类是背景,实际就两类,可根据需要修改
# YOLO的anchors的5个尺寸
ANCHORS = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]
为了损失函数,需要将标注信息的格式修改,应符合网络输出的格式: [b, 16, 16, 5, (4+1+num_classes)],具体代码如下:
def process_true_boxes(gt_boxes, anchors):
"""
计算一张图片的真实标签信息
:param gt_boxes:
:param anchors:
:return:
"""
# gt_boxes: [40,5] 一张真实标签的位置坐标信息,40是虚数,根据实际情况来定
# 512//16=32
# 计算网络模型从输入到输出的缩小比例
scale = IMGSZ // GRIDSZ
# [5,2] 将anchors转化为矩阵形式,一行代表一个anchors
anchors = np.array(anchors).reshape((5, 2))
# mask for object
# 用来判断该方格位置的anchors有没有目标,每个方格有5个anchors
detector_mask = np.zeros([GRIDSZ, GRIDSZ, 5, 1])
# x-y-w-h-l
# 在输出方格的尺寸上[16, 16, 5]制作真实标签, 用于和预测输出值做比较,计算损失值
matching_gt_box = np.zeros([GRIDSZ, GRIDSZ, 5, 5])
# [40,5] x1-y1-x2-y2-l => x-y-w-h-l
# 制作一个numpy变量,用于存储一张图片真实标签转换格式后的数据
# 将左上角与右下角坐标转化为中心坐标与宽高的形式
# [x_min, y_min, x_max, y_max] => [x_center, y_center,