【Segmentation】CTDNet 2021

Paper: Complementary Trilateral Decoder for Fast and Accurate Salient Object Detection | Proceedings of the 29th ACM International Conference on Multimedia

Code: iCVTEAM/CTDNet: Code for ACM MM2021 paper "Complementary Trilateral Decoder for Fast and Accurate Salient Object Detection" (github.com)

"Title":2021 ACM MM oral,北航和鹏城实验室

💡What's New

  • An efficient and effective SOD model
  • Architectural novelty:
    • U-shape → Trilateral Decoder (layer-by-layer → "coarse-fine-finer")

    • Shortcomings of the U-shape: ① low-level features have high resolution, making computation heavy and slow; ② spatial information lost during downsampling is hard to recover by simply merging hierarchical features; ③ high-level semantics get diluted, and the limited receptive field ignores global context, causing under-segmentation; ④ missing boundary information leads to poor boundary quality
  • The three branches are complementary in that: ① they draw on different stages of a shared encoder; ② three dedicated fusion modules are proposed to combine them

⌨️Paper with Code

Encoder

  • The backbone is ResNet-50/ResNet-18 pretrained on ImageNet
  • The lowest stage (Stage-1) is discarded to speed up computation

Decoder

Semantic Path

初始"coarse"阶段:通过较大的感受野,捕获丰富的语义上下文和全局上下文,获得显著目标的准确位置,粗糙的边界。

  • Channel attention: Global Average Pooling (GAP) appended after the backbone reduces the last two dimensions to 1×1 (effectively per-channel weights)
  • Speed-up: 1×1 conv + BN + ReLU (conv-1) on Stage-4 and Stage-5 limits the channel count to 64
  • Local U-shape structure: FFM fuses the upsampled GAP output with the two highest-level features (Stage-4 and Stage-5)
  • FFM fusion strategy: compared with addition or concatenation, element-wise multiplication avoids redundant information, suppresses background noise, and is faster
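
The code snippets below come from the official repo and rely on a couple of small helpers (conv3x3/conv1x1 and weight_init) plus the usual imports. A minimal sketch of what they are assumed to look like (the repo's actual definitions may differ in detail):

# Assumed imports and helpers for the snippets below (a sketch, not the repo's exact code)
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_chan, out_chan, stride=1, bias=False):
    # 3x3 conv with padding 1, so spatial resolution is preserved at stride 1
    return nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride, padding=1, bias=bias)

def conv1x1(in_chan, out_chan, bias=False):
    return nn.Conv2d(in_chan, out_chan, kernel_size=1, bias=bias)

def weight_init(module):
    # assumed: Kaiming init for convs, constant init for BN (the repo ships its own version)
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)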

# Feature Fusion Module
class FFM(nn.Module):
    def __init__(self, channel):
        '''Fuse by element-wise multiplication, then refine with two 3x3 convs'''
        super(FFM, self).__init__()
        self.conv_1 = conv3x3(channel, channel)
        self.bn_1 = nn.BatchNorm2d(channel)
        self.conv_2 = conv3x3(channel, channel)
        self.bn_2 = nn.BatchNorm2d(channel)

    def forward(self, x_1, x_2):
        out = x_1 * x_2  # fuse x_1 and x_2 by element-wise multiplication
        out = F.relu(self.bn_1(self.conv_1(out)), inplace=True)
        out = F.relu(self.bn_2(self.conv_2(out)), inplace=True)
        return out

    def initialize(self):
        weight_init(self)
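
A quick sanity check of the fusion (illustrative shapes only, mimicking Stage-5 features at 1/32 of a 352×352 input):

ffm = FFM(64)
a = torch.randn(1, 64, 11, 11)  # 352/32 = 11
b = torch.randn(1, 64, 11, 11)
print(ffm(a, b).shape)          # torch.Size([1, 64, 11, 11])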

Spatial Path

The intermediate "fine" stage: preserves more spatial detail, recovering the precise structure of salient objects with reasonably accurate boundaries.

  • 1/8-resolution features: rich spatial detail
  • SAM spatial attention: global average pooling and global max pooling along the channel dimension; the two resulting maps are concatenated and convolved into a single channel, a sigmoid maps it to (0,1) as an attention weight multiplied element-wise across all channels, and a final conv reduces the result to 64 channels

  • CAM fusion strategy: for the complementary Dp1 and Dp2 (Dp2 has twice the resolution of Dp1), up/downsample each to the other's resolution via bilinear interpolation, then cross-fuse (multiply and convolve at each of the two resolutions, upsample the low-resolution result, and merge the pair with an FFM) to strengthen the feature representation

# Spatial Attention Module
class SAM(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(SAM, self).__init__()
        self.conv_atten = conv3x3(2, 1)
        self.conv = conv3x3(in_chan, out_chan)
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        atten = torch.cat([avg_out, max_out], dim=1)
        atten = torch.sigmoid(self.conv_atten(atten))
        out = torch.mul(x, atten)
        out = F.relu(self.bn(self.conv(out)), inplace=True)
        return out

    def initialize(self):
        weight_init(self)
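
For example, SAM would map ResNet-18 Stage-3 features (128 channels at 1/8 of a 352×352 input) down to 64 channels (illustrative shapes only):

sam = SAM(128, 64)
l3 = torch.randn(1, 128, 44, 44)  # 352/8 = 44
print(sam(l3).shape)              # torch.Size([1, 64, 44, 44])
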
# Cross Aggregation Module
class CAM(nn.Module):
    def __init__(self, channel):
        super(CAM, self).__init__()
        self.down = nn.Sequential(
            conv3x3(channel, channel, stride=2),
            nn.BatchNorm2d(channel)
        )
        self.conv_1 = conv3x3(channel, channel)
        self.bn_1 = nn.BatchNorm2d(channel)
        self.conv_2 = conv3x3(channel, channel)
        self.bn_2 = nn.BatchNorm2d(channel)
        self.mul = FFM(channel)

    def forward(self, x_high, x_low):
        # x_high: lower-resolution semantic features; x_low: higher-resolution spatial features
        left_1 = x_low
        left_2 = F.relu(self.down(x_low), inplace=True)   # downsample x_low to x_high's resolution
        right_1 = F.interpolate(x_high, size=x_low.size()[2:], mode='bilinear', align_corners=True)
        right_2 = x_high
        # cross fusion: element-wise product + conv at each of the two resolutions
        left = F.relu(self.bn_1(self.conv_1(left_1 * right_1)), inplace=True)
        right = F.relu(self.bn_2(self.conv_2(left_2 * right_2)), inplace=True)
        # left = F.relu(left_1 * right_1, inplace=True)
        # right = F.relu(left_2 * right_2, inplace=True)
        right = F.interpolate(right, size=x_low.size()[2:], mode='bilinear', align_corners=True)
        out = self.mul(left, right)  # final FFM fusion at the higher resolution
        return out

    def initialize(self):
        weight_init(self)
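
CAM takes the 1/16-resolution semantic-path output and the 1/8-resolution spatial-path output and returns a fused map at the higher (1/8) resolution (illustrative shapes only):

cam = CAM(64)
x_high = torch.randn(1, 64, 22, 22)  # semantic path, 1/16
x_low = torch.randn(1, 64, 44, 44)   # spatial path, 1/8
print(cam(x_high, x_low).shape)      # torch.Size([1, 64, 44, 44])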

Boundary Path

The final "finer" stage: low-level local features plus high-level location features (with extra boundary supervision) produce sharp salient-object boundaries

  • 1/4-resolution features: better boundary information, but also noise from the boundaries of non-salient objects
  • Injected high-level features: high-level semantic localization serves as guidance, enhancing salient-object boundaries while suppressing non-salient ones

  • BRM fusion strategy: upsample Dp12 to Dp3's resolution and add them; global pooling + conv1×1 + sigmoid yields channel attention that reweights the summed features; the summed and reweighted features are then fused (added) and enhanced (two conv layers)

# Boundary Refinement Module
class BRM(nn.Module):
    def __init__(self, channel):
        super(BRM, self).__init__()
        self.conv_atten = conv1x1(channel, channel)
        self.conv_1 = conv3x3(channel, channel)
        self.bn_1 = nn.BatchNorm2d(channel)
        self.conv_2 = conv3x3(channel, channel)
        self.bn_2 = nn.BatchNorm2d(channel)

    def forward(self, x_1, x_edge):
        # x = torch.cat([x_1, x_edge], dim=1)
        x = x_1 + x_edge                           # fuse by addition
        atten = F.avg_pool2d(x, x.size()[2:])      # global average pooling -> (N, C, 1, 1)
        atten = torch.sigmoid(self.conv_atten(atten))
        out = torch.mul(x, atten) + x              # channel-reweighted features + residual
        out = F.relu(self.bn_1(self.conv_1(out)), inplace=True)
        out = F.relu(self.bn_2(self.conv_2(out)), inplace=True)
        return out

    def initialize(self):
        weight_init(self)
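
BRM expects both inputs at the same 1/4 resolution (in the full model, Dp12 is upsampled beforehand; illustrative shapes only):

brm = BRM(64)
x_1 = torch.randn(1, 64, 88, 88)     # Dp12 upsampled to 1/4
x_edge = torch.randn(1, 64, 88, 88)  # boundary-path features at 1/4
print(brm(x_1, x_edge).shape)        # torch.Size([1, 64, 88, 88])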

Overall CTDNet model definition:

class CTDNet(nn.Module):
    def __init__(self, cfg):
        super(CTDNet, self).__init__()
        self.cfg = cfg
        block = BasicBlock
        self.bkbone = ResNet(block, [2, 2, 2, 2])

        self.path1_1 = nn.Sequential(
            conv1x1(512 * block.expansion, 64),
            nn.BatchNorm2d(64)
        )
        self.path1_2 = nn.Sequential(
            conv1x1(512 * block.expansion, 64),
            nn.BatchNorm2d(64)
        )
        self.path1_3 = nn.Sequential(
            conv1x1(256 * block.expansion, 64),
            nn.BatchNorm2d(64)
        )

        self.path2 = SAM(128 * block.expansion, 64)

        self.path3 = nn.Sequential(
            conv1x1(64 * block.expansion, 64),
            nn.BatchNorm2d(64)
        )

        self.fuse1_1 = FFM(64)
        self.fuse1_2 = FFM(64)
        self.fuse12 = CAM(64)
        self.fuse3 = FFM(64)
        self.fuse23 = BRM(64)

        self.head_1 = conv3x3(64, 1, bias=True)
        self.head_2 = conv3x3(64, 1, bias=True)
        self.head_3 = conv3x3(64, 1, bias=True)
        self.head_4 = conv3x3(64, 1, bias=True)
        self.head_5 = conv3x3(64, 1, bias=True)
        self.head_edge = conv3x3(64, 1, bias=True)

        self.initialize()

    def forward(self, x, shape=None):
        shape = x.size()[2:] if shape is None else shape
        l1, l2, l3, l4, l5 = self.bkbone(x)

        # Eq. (1): Up(GAP(E5))
        path1_1 = F.avg_pool2d(l5, l5.size()[2:])
        path1_1 = self.path1_1(path1_1)
        path1_1 = F.interpolate(path1_1, size=l5.size()[2:], mode='bilinear', align_corners=True)   # 1/32
        # Eq. (1): F1x1(E5)
        path1_2 = F.relu(self.path1_2(l5), inplace=True)                                            # 1/32
        # Eq. (1): FFM1
        path1_2 = self.fuse1_1(path1_1, path1_2)                                                    # 1/32
        # Eq. (1): Up(FFM1)
        path1_2 = F.interpolate(path1_2, size=l4.size()[2:], mode='bilinear', align_corners=True)   # 1/16
        # Eq. (1): F1x1(E4)
        path1_3 = F.relu(self.path1_3(l4), inplace=True)                                            # 1/16
        # Eq. (1): FFM2
        path1 = self.fuse1_2(path1_2, path1_3)                                                      # 1/16
        # path1 = F.interpolate(path1, size=l3.size()[2:], mode='bilinear', align_corners=True)

        path2 = self.path2(l3)                                                                      # 1/8
        path12 = self.fuse12(path1, path2)                                                          # 1/8
        path12 = F.interpolate(path12, size=l2.size()[2:], mode='bilinear', align_corners=True)     # 1/4

        path3_1 = F.relu(self.path3(l2), inplace=True)                                              # 1/4
        path3_2 = F.interpolate(path1_2, size=l2.size()[2:], mode='bilinear', align_corners=True)   # 1/4
        path3 = self.fuse3(path3_1, path3_2)                                                        # 1/4

        path_out = self.fuse23(path12, path3)                                                       # 1/4

        logits_1 = F.interpolate(self.head_1(path_out), size=shape, mode='bilinear', align_corners=True)
        logits_edge = F.interpolate(self.head_edge(path3), size=shape, mode='bilinear', align_corners=True)

        if self.cfg.mode == 'train':
            logits_2 = F.interpolate(self.head_2(path12), size=shape, mode='bilinear', align_corners=True)
            logits_3 = F.interpolate(self.head_3(path1), size=shape, mode='bilinear', align_corners=True)
            logits_4 = F.interpolate(self.head_4(path1_2), size=shape, mode='bilinear', align_corners=True)
            logits_5 = F.interpolate(self.head_5(path1_1), size=shape, mode='bilinear', align_corners=True)
            return logits_1, logits_edge, logits_2, logits_3, logits_4, logits_5
        else:
            return logits_1, logits_edge

    def initialize(self):
        if self.cfg.snapshot:
            self.load_state_dict(torch.load(self.cfg.snapshot))
        else:
            weight_init(self)
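
With the repo checked out (so ResNet, BasicBlock, and the helpers are importable), a forward pass can be sanity-checked roughly as follows; cfg here is a minimal stand-in for the repo's config object, covering only the two attributes used above:

from types import SimpleNamespace

cfg = SimpleNamespace(mode='test', snapshot=None)  # assumed minimal config
net = CTDNet(cfg).eval()
x = torch.randn(1, 3, 352, 352)
with torch.no_grad():
    logits_1, logits_edge = net(x)
print(logits_1.shape, logits_edge.shape)  # both torch.Size([1, 1, 352, 352])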

Loss

SOD tasks typically use BCE loss (per-pixel error between the predicted mask and the ground truth) or IoU loss (structural similarity rather than individual pixels):

The model outputs six masks in total. The boundary output Dp3 uses BCE loss only, while the rest (Dp123/Dp12/Dp1/Eg5/E6) all use IoU loss + 0.6*BCE loss.
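
Written out, the combined objective implemented below is (with $p$ the sigmoid prediction, $g$ the ground-truth mask, and the $+1$ terms as smoothing):

$L = L_{IoU} + 0.6\,L_{BCE},\qquad L_{IoU} = 1 - \dfrac{\sum p\,g + 1}{\sum p + \sum g - \sum p\,g + 1}$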

def total_loss(pred, mask):
    pred = torch.sigmoid(pred)
    bce_loss = nn.BCELoss()
    bce = bce_loss(pred, mask)

    # soft IoU over the spatial dims; the +1 terms smooth the ratio for empty masks
    inter = (pred * mask).sum(dim=(2, 3))
    union = (pred + mask).sum(dim=(2, 3))
    iou = 1 - (inter + 1) / (union - inter + 1)
    iou = iou.mean()
    return iou + 0.6*bce

# Inside the training loop: deep supervision with weights decaying by depth
out1, out_edge, out2, out3, out4, out5 = net(image)
loss1 = total_loss(out1, mask)
loss_edge = bce_loss(out_edge, edge)
loss2 = total_loss(out2, mask)
loss3 = total_loss(out3, mask)
loss4 = total_loss(out4, mask)
loss5 = total_loss(out5, mask)
loss = loss1 + loss_edge + loss2/2 + loss3/4 + loss4/8 + loss5/16

⚗️Experiment Keypoints

  • Visualization of the three "coarse-fine-finer" stages

  • Datasets:
    ECSSD (1,000), PASCAL-S (850), DUTS (15,552), HKU-IS (4,447), DUT-OMRON (5,168)
  • Metrics: Mean Absolute Error (MAE, pixel-level mean absolute error), F-measure (combining precision and recall), and E-measure (enhanced-alignment measure)
  • Preprocessing: resize to 352×352 with random crop + random horizontal flip at training time; plain resize at inference (a sketch follows the optimizer snippet below)
  • Optimizer: SGD with momentum 0.9 + weight decay 5e-4
  • Batch size 32, 40 epochs (1080Ti)
  • Other training strategies: warm-up + linear-decay learning rate, with a peak lr of 5e-3 for the pretrained backbone and 5e-2 for the rest of the network
# enc_params: backbone parameters; dec_params: all other parameters
optimizer = torch.optim.SGD([{'params': enc_params}, {'params': dec_params}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True)

for epoch in range(cfg.epoch):
    # triangular schedule: the factor rises linearly to 1 at mid-training, then decays
    # linearly (warm-up + linear decay); the backbone lr is scaled down by 0.1
    optimizer.param_groups[0]['lr'] = (1 - abs((epoch + 1) / (cfg.epoch + 1) * 2 - 1)) * cfg.lr * 0.1
    optimizer.param_groups[1]['lr'] = (1 - abs((epoch + 1) / (cfg.epoch + 1) * 2 - 1)) * cfg.lr
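
A minimal sketch of the train-time preprocessing described above, written with torchvision for illustration (the repo implements its own numpy-based transforms; the 320×320 crop size is a placeholder, not taken from the paper):

import random
import torchvision.transforms as T
import torchvision.transforms.functional as TF

def train_transform(image, mask, size=352, crop=320):
    # resize image and mask to 352x352
    image = TF.resize(image, [size, size])
    mask = TF.resize(mask, [size, size])
    # joint random crop
    i, j, h, w = T.RandomCrop.get_params(image, output_size=(crop, crop))
    image, mask = TF.crop(image, i, j, h, w), TF.crop(mask, i, j, h, w)
    # joint random horizontal flip
    if random.random() < 0.5:
        image, mask = TF.hflip(image), TF.hflip(mask)
    return image, mask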

📖Background Knowledge

  • n× downsampling = 1/n of the input resolution
  • Low-level = shallow (small stage index); high-level = deep (large stage index)
  • Visual attention mechanisms: channel attention (CAM), spatial attention (SAM), hybrid attention (CBAM), self-attention
  • U-shape structures (e.g., FCN, SegNet, U-Net)
  • Salient object detection (SOD)
    • EGNet: a strong backbone (e.g., ResNet-50 or ResNet-101) + a complicated decoder

          paper: [1908.08297] EGNet: Edge Guidance Network for Salient Object Detection (arxiv.org)

  • U-shape: PiCANet, TDBU, ASNet, MINet, DASNet, PFSNet
  • Efficient: PoolNet, CPD, ITSDNet
  • Fixing the U-shape's boundary inaccuracy (introducing an additional boundary-aware branch or a boundary-aware loss function): C2SNet, BASNet, EGNet, BANet, PAGE, AFNet, SCRN
  • Similar ideas: the BiSeNet series, PIDNet (CVPR 2023)