目录
《深入理解 mmdet 中的 Associative Embedding Loss》
在目标检测领域,损失函数的设计对于模型的性能至关重要。mmdet(OpenMMLab Detection Toolbox)提供了多种强大的损失函数,其中AssociativeEmbeddingLoss
是一个值得深入探讨的损失函数。
一、引入必要的库和模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS
这里导入了 PyTorch 的相关模块以及 mmdet 中的MODELS
注册器。
二、ae_loss_per_image
函数解析
def ae_loss_per_image(tl_preds, br_preds, match):
"""Associative Embedding Loss in one image."""
这个函数计算单个图像的关联嵌入损失(Associative Embedding Loss)。
-
三种情况判断:
- 如果图像中没有对象(
len(match) == 0
),则将拉损失(pull loss)和推损失(push loss)都设置为 0。 - 如果图像中有一个对象,推损失为 0,拉损失由该对象的两个角(左上角和右下角)的嵌入向量计算得出。
- 如果图像中有多个对象,拉损失由每个对象的角对计算得出,推损失由每个对象与其他所有对象计算得出,使用对角线为 0 的混淆矩阵来计算推损失。
- 如果图像中没有对象(
-
对象存在时的处理:
- 遍历匹配列表
match
,提取每个对象的左上角和右下角的嵌入向量,并计算它们的均值向量。 - 对于拉损失,计算左上角和右下角向量与均值向量的平方差之和,并除以对象数量
N
。 - 对于推损失,计算混淆矩阵,其中元素是两个不同对象的均值向量之差的绝对值与边缘值(这里设置为 1)的差值。如果有多个对象(
N > 1
),则将混淆矩阵中的正值进行求和并除以对象数量的组合数N * (N - 1)
。
tl_list, br_list, me_list = [], [], [] if len(match) == 0: pull_loss = tl_preds.sum() * 0. push_loss = tl_preds.sum() * 0. else: for m in match: [tl_y, tl_x], [br_y, br_x] = m tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1) br_e = br_preds[:, br_y, br_x].view(-1, 1) tl_list.append(tl_e) br_list.append(br_e) me_list.append((tl_e + br_e) / 2.0) tl_list = torch.cat(tl_list) br_list = torch.cat(br_list) me_list = torch.cat(me_list) assert tl_list.size() == br_list.size() N, M = tl_list.size() pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2) pull_loss = pull_loss.sum() / N margin = 1 conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list conf_weight = 1 - torch.eye(N).type_as(me_list) conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs()) if N > 1: push_loss = F.relu(conf_mat).sum() / (N * (N - 1)) else: push_loss = tl_preds.sum() * 0.
最终,函数返回拉损失和推损失。
- 遍历匹配列表
三、AssociativeEmbeddingLoss
类解析
@MODELS.register_module()
class AssociativeEmbeddingLoss(nn.Module):
"""Associative Embedding Loss."""
这个类是AssociativeEmbeddingLoss
的实现,继承自 PyTorch 的nn.Module
类。
-
初始化:
- 在构造函数中,接收拉损失权重
pull_weight
和推损失权重push_weight
,并将它们存储在实例变量中。
def __init__(self, pull_weight=0.25, push_weight=0.25): super(AssociativeEmbeddingLoss, self).__init__() self.pull_weight = pull_weight self.push_weight = push_weight
- 在构造函数中,接收拉损失权重
-
前向传播:
- 在
forward
函数中,遍历输入的批次,对每个图像调用ae_loss_per_image
函数计算拉损失和推损失。 - 然后将每个图像的损失乘以对应的权重,并在批次维度上进行累加。
- 最终返回拉损失和推损失的总和。
def forward(self, pred, target, match): """Forward function.""" batch = pred.size(0) pull_all, push_all = 0.0, 0.0 for i in range(batch): pull, push = ae_loss_per_image(pred[i], target[i], match[i]) pull_all += self.pull_weight * pull push_all += self.push_weight * push return pull_all, push_all
- 在
AssociativeEmbeddingLoss
通过计算拉损失和推损失,鼓励来自同一对象的角点嵌入向量靠近,同时使不同对象的角点嵌入向量远离,从而提高目标检测模型对对象边界的定位准确性。