NanoDet代码逐行精读与修改(五.2)计算Loss

neozng1@hnu.edu.cn

详细地介绍了Loss计算的重难点,精细到给出了每一个tensor的维度和其中每一个元素的意义!

5.2.标签分配和Loss计算

5.2.1. 计算Loss的模块和流程

loss的运算流程如下,当aux_headAGM启用的时候,aux_headfpnaux_fpn获取featmap随后输出预测,在detach_epoch(需要自己设置的参数,在训练了detach_epoch后标签分配将由检测头自己进行)内,使用AGM的输出来对head的预测值进行标签分配。

先根据输入的大小获取priors的网格,然后将AGM预测的分数和检测框根据prior进行标签分配,把分配结果提交给head,再使用head的输出和gt计算loss最后反向传播完成一步迭代。

5.2.2. 计算Loss

这里我们先“假装”已经知道封装好的函数在做什么,先关注整体的计算流程:

def loss(self, preds, gt_meta, aux_preds=None):
    """Compute losses.
    Args:
        preds (Tensor): Prediction output. head的输出
        gt_meta (dict): Ground truth information. 包含gt位置和标签的字典,还有原图像的数据
        aux_preds (tuple[Tensor], optional): Auxiliary head prediction output.
        如果AGM还没有detach,会用AGM的输出进行标签分配

    Returns:
        loss (Tensor): Loss tensor.
        loss_states (dict): State dict of each loss.
    """
    # 把gt相关的数据分离出来,这两个数据都是list,长度为batchsize的大小
    # 每个list中都包含了它们各自对应的图像上的gt和label
    gt_bboxes = gt_meta["gt_bboxes"]
    gt_labels = gt_meta["gt_labels"]
    # 一会要传送到GPU上
    device = preds.device
    # 得到本次loss计算的batch数,pred是3维tensor,可以参见第一部分关于推理的介绍
    batch_size = preds.shape[0]

    # 对"img"信息取shape就可以得到图像的长宽,这里请看DataSet和DataLoader了解训练数据的详细格式
    # 所有图片都会在前处理中被resize成网络的输入大小,不足的则直接加zero padding
    input_height, input_width = gt_meta["img"].shape[2:]

    # 因为稍后要布置priors,这里要计算出feature map的大小
    # 如果修改了输入或者采样率,输入无法被stride整除,要取整
    # 默认是对齐的,应该为[40,40],[20,20],[10,10],[5,5]
    featmap_sizes = [
        (math.ceil(input_height / stride), math.ceil(input_width) / stride)
        for stride in self.strides
    ]

    # get grid cells of one image
    # 在不同大小的stride上放置一组priors,默认四个检测头也就是四个不同尺寸的stride
    # 最后返回的是tensor维度是[batchsize,strideW*strideH,4]
    # 其中每一个都是[x,y,strideH,strideW]的结构,当featmap不是正方形的时候两个stride不相等
    mlvl_center_priors = [
        self.get_single_level_center_priors(
            batch_size,
            featmap_sizes[i],
            stride,
            dtype=torch.float32,
            device=device,
        )
        for i, stride in enumerate(self.strides)
    ]

    # 按照第二个维度拼接后的prior的维度是[batchsize,40x40+20x20+10x10+5x5=2125,4]
    # 其中四个值为[cx,cy,strideW,stridH]
    center_priors = torch.cat(mlvl_center_priors, dim=1)

    # 把预测值拆分成分类和框回归
    # cls_preds的维度是[batchsize,2125*class_num],reg_pred是[batchsize,2125*4*(reg_max+1)]
    cls_preds, reg_preds = preds.split(
        [self.num_classes, 4 * (self.reg_max + 1)], dim=-1
    )

    # 对reg_preds进行积分得到位置预测,reg_reds表示的是一条边的离散分布,
    # 积分就得到位置(对于离散来说就是加权求和),distribution_project()是Integral函数,稍后讲解
    # 乘以stride(在center_priors的最后两个位置
评论 8
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值