MaskFormer源码整理

深度解析MaskFormer模型:结构、核心组件与损失函数
本文详细介绍了MaskFormer模型的组成部分,包括backbone(使用ResNet或Swin-Transformer)、MaskFormerHead中的PixelDecoder和TransformerPredictor。PixelDecoder实现了FPN结构并结合Transformer处理特征,TransformerPredictor则包含类别预测和实例掩模预测两个分支。损失函数采用了匈牙利算法进行匹配和计算,同时结合了focalloss和diceloss。

前言

按照计划对MaskFormer进行了复现,配置了模型所需的环境,在实验室服务器上跑通了训练脚本。

对源码中模型搭建的部分进行学习,结合原论文深入理解整体结构。

1. 模型概述

整个代码结构基于detectron2框架,所以会有很多注册的指令和from_config()函数,这两个都不影响代码的逻辑,在看源码的时候不必纠结,把所有的from_config()都看成从配置文件读取相关变量的值即可,具体的值可以在config/xx/xx.yaml文件种找到;注册指令是为了detectron2可以检测到,看源码的时候可以直接忽略这条指令。

总共四个核心类分别是:

  • mask_former/mask_former_model.py中的MaskFormer
  • mask_former/modeling/heads/mask_former_head.py中的MaskFormerHead
  • mask_former/modeling/heads/pixel_decoder.py中的TransformerEncoderPixelDecoder
  • mask_former/modeling/transformer/transformer_predictor.py中的TransformerPredictor
@META_ARCH_REGISTRY.register()
class MaskFormer(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        panoptic_on: bool,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):

其中,MaskFormer类是模型的入口,整体结构都在这个类中。在这个类的构造函数中可以看到模型是由backbone、sem_seg_head、criterion构成,通过debug,可以打印出这三个参数所对应的模块。其中backbone对应下图中的backbone,sem_seg_head包含了除backbone外的所有部分,criterion是损失函数的类,对应下图中的两个Loss。

2. backbone

作者提供了resnet和swin-transformer两种backbone,输出outs[“res{}”.format(i + 2)] = out,是一个字典,关键字是res2、res3等。

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {
   
   }
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{
     
     i}")
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs["res{}".format(i + 2)] = out

        
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值