yolov5的anchors的编解码原理
yolov5的anchors及bbox的编解码原理
1.anchor的生成
anchor的生成分为两个步骤:
首先会生成位于feature map左上角的base anchor,众所周知yolov3有9组anchor(图2),3个不同尺度大小的输出,而3个输出各对应三组anchor;生成base anchor之后会根据不同feature map的stride对base anchor进行平移复制。
1)base anchor的生成
现假设输入图片为448640,那么最后输出的三个featuremap大小从小到大依次为:1420、2840、5680,对应的stride从小到大依次为:32,16,8。**以1420的feature map为例**,stride=32,对应9组anchor中的第一组:[(116, 90), (156, 198), (373, 326)]。那么图1中的代码就干两件事儿,生成base anchor:
2)base anchor的平移和复制
平移复制base anchor,stride=32(feature map为14*20,那么相对于原图一个网格边长就是32):
最终,在14×20的feature map上有14×20×3=840个anchor,在28×40的feature map上有28×40×3=3360个anchor,在5680的feature map上有5680*3=13440个anchor,共计17640个anchor。
anchor生成完毕之后,再做一些操作之后就是decoder了,decode完毕之后就是NMS
2.bbox的编解码过程
编解码代码:
@BBOX_CODERS.register_module()
class YOLOV5BBoxCoder(BaseBBoxCoder):
"""YOLO BBox coder.
Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divide
image into grids, and encode bbox (x1, y1, x2, y2) into (cx, cy, dw, dh).
cx, cy in [0., 1.], denotes relative center position w.r.t the center of
bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`.
Args:
eps (float): Min value of cx, cy when encoding.
"""
def __init__(self, eps=1e-6):
super(BaseBBoxCoder, self).__init__()
self.eps = eps
def encode(self, bboxes, gt_bboxes, stride):
"""
YOLOv5 and YOLOx of YOLO series don't have encode
"""
raise NotImplementedError("YOLOv5 doesn't have encoder!")
# target
def delta_bbox(self, bboxes, pred_bboxes):
"""Get delta_bboxes from anchors and pred_bboxes.
Args:
bboxes: anchors.(x1,y1,x2,y2)
pred_bboxes: output bboxes of YOLOv5.(x,y,w,h)
"""
# anchors' width and height
w, h = bboxes[..., 2] - bboxes[..., 0], bboxes[..., 3] - bboxes[..., 1]
# center of x & y of pred_bbox
x_center_pred = (pred_bboxes[..., 0].sigmoid() - 0.5) * 2
y_center_pred = (pred_bboxes[..., 1].sigmoid() - 0.5) * 2
# w & h of pred_bbox
w_pred = (pred_bboxes[..., 2].sigmoid() * 2) ** 2 * w
h_pred = (pred_bboxes[..., 3].sigmoid() * 2) ** 2 * h
delta_bboxes = torch.stack(
(x_center_pred - w_pred / 2, y_center_pred - h_pred / 2,
x_center_pred + w_pred / 2, y_center_pred + h_pred / 2),
dim=-1
)
return delta_bboxes
def decode(self, bboxes, pred_bboxes, stride):
"""Apply transformation `pred_bboxes` to `boxes`.
Args:
boxes (torch.Tensor): Basic boxes, e.g. anchors.
pred_bboxes (torch.Tensor): Encoded boxes with shape
stride (torch.Tensor | int): Strides of bboxes.
Returns:
torch.Tensor: Decoded boxes.
"""
assert pred_bboxes.size(0) == bboxes.size(0)
assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
w = bboxes[..., 2] - bboxes[..., 0]
h = bboxes[..., 3] - bboxes[..., 1]
# Get outputs x, y
# 由于mmdetection的anchor已经偏移了0.5,故*2的操作要放在外面
x_center_pred = (pred_bboxes[..., 0] - 0.5) * 2 * stride + x_center
y_center_pred = (pred_bboxes[..., 1] - 0.5) * 2 * stride + y_center
# yolov5中正常情况应该是
# x_center_pred = (pred_bboxes[..., 0] * 2. - 0.5 + grid[:, 0]) * stride # xy
# y_center_pred = (pred_bboxes[..., 1] * 2. - 0.5 + grid[:, 1]) * stride # xy
# wh也需要sigmoid,然后乘以4来还原
w_pred = (pred_bboxes[..., 2].sigmoid() * 2) ** 2 * w
h_pred = (pred_bboxes[..., 3].sigmoid() * 2) ** 2 * h
decoded_bboxes = torch.stack(
(x_center_pred - w_pred / 2, y_center_pred - h_pred / 2,
x_center_pred + w_pred / 2, y_center_pred + h_pred / 2),
dim=-1)
return decoded_bboxes
bbox框的编解码原理
其实整个过程可以理解成anchor在起什么作用,下面就简单的说明一下。我们知道yolo直接去预测出目标的位置坐标是不现实的,所以yolo的作者就提出去预测gt和anchor的偏移量。
示意图说明:
上图中我们要先得到Anchor中心点所在网格的左上角坐标a=(6,2),再得到人工标注的中心点坐标g=(6.3,3.3),我们取到点g相对于点a的偏移量,记作p=(0.3,1.3),以p点坐标为中心,宽高与人工标注相同,得到target框如下:
pxy = pxy.sigmoid() * 2 - 0.5 # pxy即网络输出的框的xy坐标
pwh = (pwh.sigmoid() * 2) ** 2 * anchor[i] # anchor[i]即对应anchor的宽高
# 可以看出来,上述变化将xy坐标约束到区间[-0.5,1.5],将wh约束到[0,4]*anhcor_wh
将encode应用到上图中的网络输出pred_bbox和anchor,设pred_bbox中心点为r=(3.3,5.2),w=1.9,h=2.2,经过上述两行代码,r=(3.3,5.2)变成(1.428, 1.489),w=1.9变成1.89,h=2.2变成2.43,将得到的新框记作delta,如下:
至此,encode就完成了,一开始我们的网络没有经过训练,框都是乱跑的,经过编码之后约束到一个区间。算法会将上图中的delta框和target送入损失函数中,让网络去学习、去更新参数,使得delta框与target框趋于重合。我们假设网络学习效果非常好,现在的delta框已经与target框几乎重合:
注意了!理一下逻辑。这个时候,我们网络的输出pred_bboxes经过encode之后得到的框(记作res)几乎就是target框,还记得之前的Anchor中心点所在网格的左上角坐标a=(6,2)吗,加到res上不就是人工标注嘛!!
至此就结束了。这边可能有一点绕,因为我觉得下面的逻辑更加合适: