目标检测网络的训练大致是如下的流程:
- 设置各种超参数
- 定义数据加载模块 dataloader
- 定义网络 model
- 定义损失函数 loss
- 定义优化器 optimizer
- 遍历训练数据,预测-计算loss-反向传播
首先,我们导入必要的库,然后设定各种超参数
后处理
目标框信息解码
之前我们的提到过,模型不是直接预测的目标框信息,而是预测的基于anchor的偏移,且经过了编码。因此后处理的第一步,就是对模型的回归头的输出进行解码,拿到真正意义上的目标框的预测结果。
后处理还需要做什么呢?由于我们预设了大量的先验框,因此预测时在目标周围会形成大量高度重合的检测框,而我们目标检测的结果只希望保留一个足够准确的预测框,所以就需要使用某些算法对检测框去重。这个去重算法叫做NMS
NMS的大致算法步骤如下:
-
按照类别分组,依次遍历每个类别。
-
当前类别按分类置信度排序,并且设置一个最低置信度阈值如0.05,低于这个阈值的目标框直接舍弃。
-
当前概率最高的框作为候选框,其它所有与候选框的IOU高于一个阈值(自己设定,如0.5)的框认为需要被抑制,从剩余框数组中删除。
-
然后在剩余的框里寻找概率第二大的框,其它所有与第二大的框的IOU高于设定阈值的框被抑制。
-
依次类推重复这个过程,直至遍历完所有剩余框,所有没被抑制的框即为最终检测框。