1. 引言
在 YOLOv5 的实现中,Anchor 是一个重要的概念。它不仅影响着模型的训练过程,还在推理阶段对预测结果的解码有着重要作用。很多人可能会有疑问:为什么在 YOLOv5 的 Detect 类中使用 register_buffer(‘anchors’, …) 注册了 Anchors,却在训练过程中似乎没有直接用到它?本篇文章将详细探讨 YOLOv5 中 Anchor 在训练与推理时的区别,及其设计背后的逻辑。
2. YOLOv5 中的 Anchor 使用简介
在 YOLOv5 中,Detect 类通过 register_buffer 将 Anchor 注册为模型的缓冲区。这种方式确保了 Anchor 不会被视为模型的可训练参数,并且在模型转移到不同设备或者保存和加载时,Anchor 的值能保持一致。
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
Anchors 在推理阶段用于解码网络输出的特征图,生成实际的边界框。虽然在训练过程中 Anchors 似乎没有被直接使用,但它们在目标分配和损失计算中起到了至关重要的作用。
3. 训练阶段中的 Anchor 使用
在训练阶段,模型的主要任务是通过前向传播(forward pass)生成预测,并将这些预测与真实标签进行对比来计算损失。在这一过程中,模型输出的是相对于 Anchor 的偏移量,而不是实际的边界框。这种设计的好处是:模型可以学习相对位置,而不是绝对坐标,从而更好地适应不同尺度的物体。
在 YOLOv5 中,Anchor 的具体使用发生在损失函数中。例如,损失函数可能会使用 Anchor 来进行目标匹配和计算损失:
目标匹配:根据 Anchor 的大小和位置,将模型的预测框与真实框进行匹配。
损失计算:损失函数使用 Anchor 的信息来解码预测框,并将其与真实标签进行对比。
注意:这些操作并不在 Detect 类中进行,而是在损失函数(例如 ComputeLoss 类)中完成。因此,虽然 Anchor 在训练过程中确实有被使用,但它们的作用在 Detect 层之外。
4. 推理阶段中的 Anchor 使用
在推理阶段,模型的目标是将输出特征图解码为实际的边界框。因此,Detect 层在推理过程中需要使用注册的 anchors 缓冲区来计算每个检测层的网格和 Anchor 网格(anchor_grid)。这些网格帮助模型将预测的相对偏移量转换为实际坐标。
以下是 Detect 类的代码示例,展示了如何在推理阶段使用 Anchor:
def forward(self, x):
z = [] # inference output
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if not self.training: # inference
if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
# 解码预测框
xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1),