【医学分割】unet3+

概述

unet3+是在unet以及unet++的基础上提出来的,unet的核心是skip-connection。
而unet++也正是在这个核心上做的改动,使用了重叠的稠密卷积代替粗暴的特征融合。
而unet3+注意到,unet和unet++没有直接从多尺度信息中提取足够多的信息,基于这一点,设计了一种新的skip-connection的结构,更好的将低级细粒度信息和高级语义特征进行了融合,并且这种结构的参数量会比unet和unet++都少。同时,对于decoder的输出进行深度监督,并且提出了一种新的损失函数进行训练。另一方面,使用分类做了指导,降低在背景图中过度分割的情况。
unet和unet++没有直接从多尺度信息中提取足够多的信息指的是,unet是根据对应层的encoder的输出以及下一层decoder上采样的结果构造当前层decoder的,没有进行了多尺度信息的利用;而unet++虽然通过嵌套和密集跳过连接进行了多尺度信息的利用,但是更像是对于encoder特征的不断处理,而不是对原始特征的利用;unet3+则是对原始特征进行了多尺度信息的利用。
以下就是三种结构的比较示意图:
在这里插入图片描述

细节

Full-scale Skip Connections

在这里插入图片描述
unet3+还是在unet基础上改动得到的,故他依然是encoder-decoder的结构。encoder部分其实unet,unet++还有unet3+都是相同的,关键是decoder部分是怎么得到的。unet3+中的做法是,encoder中层数小于等于当前层的特征图经过池化(无重叠最大池化)和卷积操作得到64个通道的特征图(当然,层数相同的那一层不需要池化操作),然后decoder中层数大于当前层的特征图经过上采样(双线性插值)和卷积同样得到64个通道的特征图,接着将这些特征图concat起来,以作者采用的5层结构为例,总共会有64x5=320,这320个通道数的特征图,经过卷积、正则化、和激活函数之后就构成了decoder的一层了。
上述过程的形式化表述就是:
在这里插入图片描述
并且,有一个不得不提的是,尽管unet3+的结构相对于unet会复杂,但是在保证相同的encoder的基础上,参数量反而会少。也就是保证相同的encoder,参数量排名是:unet3+<unet<unet++

Full-scale DeepSupervision

unet3+和unet++都采用深度监督,从多尺度聚合的特征图中学习层次表示,但是监督的位置信息不同哟,这点从图中的sup位置就能发现。unet++对第一层特征图进行监督,具体的操作是在这些特征图的后面都接c(C是类别个数)个1x1的卷积核和一个sigmoid函数,得到最后的分割结果;而unet3+则是对decoder的每一个层对应一个侧输出,这个输出经过3x3的卷积、上采样和sigmoid函数,得到分割的结果。其中上采样是为了得到全尺寸的特征图,使用相同的gt进行监督。

在这里插入图片描述

开发了一种混合损失用于在三级层次结构(像素级、补丁级和图片级)中进行分割,它能捕获大尺度的和精细结构清晰的边界,损失函数:是MS-SSIM loss+focal loss+ IoU loss,其中

  • 多尺度结构相似性指数 (MM-SSIM) 损失用于为模糊边界分配更高的权重。
  • Focal loss起源于RetinaNet,用于处理类不平衡问题。
  • 使用标准IoU 损失。

Classification-guided Module

在这里插入图片描述
在encoder的最后一层或者说是decoder的最后一层,经过dropout、1x1卷积、最大池化和sigmoid函数得到两个值,然后经过argmax函数转换成单个输出,表示有或者没有目标,然后我们将这个输出与侧分割相乘,得到更加准确的分割结果。解决由于较浅层中背景噪声等残留信息导致过分割的现象

简答实现

import paddle
import paddle.nn as nn


# 两次卷积操作
# 卷积计算公式:
# 输出大小 = (输入大小 − Filter + 2Padding )/Stride+1
class VGGBlock(nn.Layer):
    def __init__(self,in_channels,out_channels):
        super(VGGBlock, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2D(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2D(out_channels),
            nn.LeakyReLU(),

            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2D(out_channels),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

# 将decoder上采样若干倍数 然后经过3x3卷积得到64个通道的特征图
class Up(nn.Layer):
    def __init__(self,scale_factor,in_channels):
        super(Up, self).__init__()
        self.layer=nn.Sequential(
            nn.UpsamplingBilinear2D(scale_factor=scale_factor),
            nn.Conv2D(in_channels,64,3,1,1)
        )

    def forward(self,x):
        return self.layer(x)

# 将encoder池化若干倍数 然后经过3x3卷积得到64个通道的特征图
# 同层的话 就不需要池化操作了
class Down(nn.Layer):
    def __init__(self,kernel_size,in_channels):
        super(Down, self).__init__()
        if kernel_size!=0:
            self.layer = nn.Sequential(
                nn.MaxPool2D(kernel_size),
                nn.Conv2D(in_channels, 64, 3, 1, 1)
            )
        else:
            self.layer = nn.Sequential(
                nn.Conv2D(in_channels, 64, 3, 1, 1)
            )

    def forward(self,x):
        return self.layer(x)


class UNet3Plus(nn.Layer):
    def __init__(self,num_classes=2,deep_supervision=False):
        super(UNet3Plus, self).__init__()

        filters=[64, 128, 256, 512, 1024]
        self.pool = nn.MaxPool2D(2)
        self.num_classes=num_classes
        self.deep_supervision=deep_supervision

        ## -------------encoder-------------
        self.encoder_1=VGGBlock(3,filters[0])
        self.encoder_2=VGGBlock(filters[0],filters[1])
        self.encoder_3=VGGBlock(filters[1],filters[2])
        self.encoder_4=VGGBlock(filters[2],filters[3])
        self.encoder_5=VGGBlock(filters[3],filters[4])

        ## -------------decoder-------------

        # 这是将concat得到的320个通道的特征图做处理 得到decoder当前层
        self.decoder = VGGBlock(filters[0] * 5, filters[0] * 5)

        # decoder4
        # 因为decoder5和encoder5就是同一个 为了后面统一 就写成decoder5了
        # 除了decoder5之外 剩下的decoder每层通道数都是320
        self.decoder_4_encoder_1=Down(8,filters[0])
        self.decoder_4_encoder_2=Down(4,filters[1])
        self.decoder_4_encoder_3=Down(2,filters[2])
        self.decoder_4_encoder_4=Down(0,filters[3])
        self.decoder_4_decoder_5=Up(2,filters[4])

        # decoder3
        self.decoder_3_encoder_1 = Down(4, filters[0])
        self.decoder_3_encoder_2 = Down(2, filters[1])
        self.decoder_3_encoder_3 = Down(0, filters[2])
        self.decoder_3_decoder_4 = Up(2, filters[0]*5)
        self.decoder_3_decoder_5 = Up(4, filters[4])

        # decoder2
        self.decoder_2_encoder_1 = Down(2, filters[0])
        self.decoder_2_encoder_2 = Down(0, filters[1])
        self.decoder_2_decoder_3 = Up(2, filters[0]*5)
        self.decoder_2_decoder_4 = Up(4, filters[0]*5)
        self.decoder_2_decoder_5 = Up(8, filters[4])

        # decoder1
        self.decoder_1_encoder_1 = Down(0, filters[0])
        self.decoder_1_decoder_2 = Up(2, filters[0]*5)
        self.decoder_1_decoder_3 = Up(4, filters[0]*5)
        self.decoder_1_decoder_4 = Up(8, filters[0]*5)
        self.decoder_1_decoder_5 = Up(16, filters[4])

        if self.deep_supervision:
            self.final1=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
            self.final2=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
            self.final3=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
            self.final4=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
        else:
            self.final = nn.Conv2D(filters[0]*5, self.num_classes, 3, 1, 1)




    def forward(self,x):
        ## -------------encoder-------------
        encoder_1=self.encoder_1(x)
        encoder_2=self.encoder_2(self.pool(encoder_1))
        encoder_3=self.encoder_3(self.pool(encoder_2))
        encoder_4=self.encoder_4(self.pool(encoder_3))
        encoder_5=self.encoder_5(self.pool(encoder_4))

        ## -------------decoder-------------

        # decoder4
        decoder_4_encoder_1=self.decoder_4_encoder_1(encoder_1)
        decoder_4_encoder_2=self.decoder_4_encoder_2(encoder_2)
        decoder_4_encoder_3=self.decoder_4_encoder_3(encoder_3)
        decoder_4_encoder_4=self.decoder_4_encoder_4(encoder_4)
        decoder_4_decoder_5=self.decoder_4_decoder_5(encoder_5)
        decoder_4=self.decoder(paddle.concat([decoder_4_decoder_5,decoder_4_encoder_4,decoder_4_encoder_3,decoder_4_encoder_2,decoder_4_encoder_1],axis=1))

        # decoder3
        decoder_3_encoder_1 = self.decoder_3_encoder_1(encoder_1)
        decoder_3_encoder_2 = self.decoder_3_encoder_2(encoder_2)
        decoder_3_encoder_3 = self.decoder_3_encoder_3(encoder_3)
        decoder_3_decoder_4 = self.decoder_3_decoder_4(decoder_4)
        decoder_3_decoder_5 = self.decoder_3_decoder_5(encoder_5)
        decoder_3=self.decoder(paddle.concat([decoder_3_decoder_5,decoder_3_decoder_4,decoder_3_encoder_3,decoder_3_encoder_2,decoder_3_encoder_1],axis=1))

        # decoder2
        decoder_2_encoder_1 = self.decoder_2_encoder_1(encoder_1)
        decoder_2_encoder_2 = self.decoder_2_encoder_2(encoder_2)
        decoder_2_decoder_3 = self.decoder_2_decoder_3(decoder_3)
        decoder_2_decoder_4 = self.decoder_2_decoder_4(decoder_4)
        decoder_2_decoder_5 = self.decoder_2_decoder_5(encoder_5)
        decoder_2=self.decoder(paddle.concat([decoder_2_decoder_5,decoder_2_decoder_4,decoder_2_decoder_3,decoder_2_encoder_2,decoder_2_encoder_1],axis=1))

        # decoder2
        decoder_1_encoder_1 = self.decoder_1_encoder_1(encoder_1)
        decoder_1_decoder_2 = self.decoder_1_decoder_2(decoder_2)
        decoder_1_decoder_3 = self.decoder_1_decoder_3(decoder_3)
        decoder_1_decoder_4 = self.decoder_1_decoder_4(decoder_4)
        decoder_1_decoder_5 = self.decoder_1_decoder_5(encoder_5)
        decoder_1=self.decoder(paddle.concat([decoder_1_decoder_5,decoder_1_decoder_4,decoder_1_decoder_3,decoder_1_decoder_2,decoder_1_encoder_1],axis=1))


        if self.deep_supervision:
            output1 = self.final1(decoder_1)
            output2 = self.final2(decoder_2)
            output3 = self.final3(decoder_3)
            output4 = self.final4(decoder_4)
            return [output1, output2, output3, output4]

        else:
            output = self.final(decoder_1)
            return output

if __name__ == '__main__':
    # x=paddle.randn(shape=[2,3,256,256])
    unet=UNet3Plus()
    # print(net(x).shape)
    paddle.summary(unet, (1,3,256,256))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值