〖MMDetection〗解析文件:/models/losses/ae_loss.py

《深入理解 mmdet 中的 Associative Embedding Loss》

在目标检测领域,损失函数的设计对于模型的性能至关重要。mmdet(OpenMMLab Detection Toolbox)提供了多种强大的损失函数,其中AssociativeEmbeddingLoss是一个值得深入探讨的损失函数。

一、引入必要的库和模块

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.registry import MODELS

这里导入了 PyTorch 的相关模块以及 mmdet 中的MODELS注册器。

二、ae_loss_per_image函数解析

def ae_loss_per_image(tl_preds, br_preds, match):
    """Associative Embedding Loss in one image."""

这个函数计算单个图像的关联嵌入损失(Associative Embedding Loss)。

  1. 三种情况判断

    • 如果图像中没有对象(len(match) == 0),则将拉损失(pull loss)和推损失(push loss)都设置为 0。
    • 如果图像中有一个对象,推损失为 0,拉损失由该对象的两个角(左上角和右下角)的嵌入向量计算得出。
    • 如果图像中有多个对象,拉损失由每个对象的角对计算得出,推损失由每个对象与其他所有对象计算得出,使用对角线为 0 的混淆矩阵来计算推损失。
  2. 对象存在时的处理

    • 遍历匹配列表match,提取每个对象的左上角和右下角的嵌入向量,并计算它们的均值向量。
    • 对于拉损失,计算左上角和右下角向量与均值向量的平方差之和,并除以对象数量N
    • 对于推损失,计算混淆矩阵,其中元素是两个不同对象的均值向量之差的绝对值与边缘值(这里设置为 1)的差值。如果有多个对象(N > 1),则将混淆矩阵中的正值进行求和并除以对象数量的组合数N * (N - 1)
    tl_list, br_list, me_list = [], [], []
    if len(match) == 0:
        pull_loss = tl_preds.sum() * 0.
        push_loss = tl_preds.sum() * 0.
    else:
        for m in match:
            [tl_y, tl_x], [br_y, br_x] = m
            tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
            br_e = br_preds[:, br_y, br_x].view(-1, 1)
            tl_list.append(tl_e)
            br_list.append(br_e)
            me_list.append((tl_e + br_e) / 2.0)
    
        tl_list = torch.cat(tl_list)
        br_list = torch.cat(br_list)
        me_list = torch.cat(me_list)
    
        assert tl_list.size() == br_list.size()
    
        N, M = tl_list.size()
    
        pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
        pull_loss = pull_loss.sum() / N
    
        margin = 1
    
        conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
        conf_weight = 1 - torch.eye(N).type_as(me_list)
        conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
    
        if N > 1:
            push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
        else:
            push_loss = tl_preds.sum() * 0.
    

    最终,函数返回拉损失和推损失。

三、AssociativeEmbeddingLoss类解析

@MODELS.register_module()
class AssociativeEmbeddingLoss(nn.Module):
    """Associative Embedding Loss."""

这个类是AssociativeEmbeddingLoss的实现,继承自 PyTorch 的nn.Module类。

  1. 初始化

    • 在构造函数中,接收拉损失权重pull_weight和推损失权重push_weight,并将它们存储在实例变量中。
    def __init__(self, pull_weight=0.25, push_weight=0.25):
        super(AssociativeEmbeddingLoss, self).__init__()
        self.pull_weight = pull_weight
        self.push_weight = push_weight
    
  2. 前向传播

    • forward函数中,遍历输入的批次,对每个图像调用ae_loss_per_image函数计算拉损失和推损失。
    • 然后将每个图像的损失乘以对应的权重,并在批次维度上进行累加。
    • 最终返回拉损失和推损失的总和。
    def forward(self, pred, target, match):
        """Forward function."""
        batch = pred.size(0)
        pull_all, push_all = 0.0, 0.0
        for i in range(batch):
            pull, push = ae_loss_per_image(pred[i], target[i], match[i])
    
            pull_all += self.pull_weight * pull
            push_all += self.push_weight * push
    
        return pull_all, push_all
    

AssociativeEmbeddingLoss通过计算拉损失和推损失,鼓励来自同一对象的角点嵌入向量靠近,同时使不同对象的角点嵌入向量远离,从而提高目标检测模型对对象边界的定位准确性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值