概述
unet3+是在unet以及unet++的基础上提出来的,unet的核心是skip-connection。
而unet++也正是在这个核心上做的改动,使用了重叠的稠密卷积代替粗暴的特征融合。
而unet3+注意到,unet和unet++没有直接从多尺度信息中提取足够多的信息,基于这一点,设计了一种新的skip-connection的结构,更好的将低级细粒度信息和高级语义特征进行了融合,并且这种结构的参数量会比unet和unet++都少。同时,对于decoder的输出进行深度监督,并且提出了一种新的损失函数进行训练。另一方面,使用分类做了指导,降低在背景图中过度分割的情况。
注
:unet和unet++没有直接从多尺度信息中提取足够多的信息
指的是,unet是根据对应层的encoder的输出以及下一层decoder上采样的结果构造当前层decoder的,没有进行了多尺度信息的利用;而unet++虽然通过嵌套和密集跳过连接进行了多尺度信息的利用,但是更像是对于encoder特征的不断处理,而不是对原始特征的利用;unet3+则是对原始特征进行了多尺度信息的利用。
以下就是三种结构的比较示意图:
细节
Full-scale Skip Connections
unet3+还是在unet基础上改动得到的,故他依然是encoder-decoder的结构。encoder部分其实unet,unet++还有unet3+都是相同的,关键是decoder部分是怎么得到的。unet3+中的做法是,encoder中层数小于等于当前层的特征图经过池化(无重叠最大池化)和卷积操作得到64个通道的特征图(当然,层数相同的那一层不需要池化操作),然后decoder中层数大于当前层的特征图经过上采样(双线性插值)和卷积同样得到64个通道的特征图,接着将这些特征图concat起来,以作者采用的5层结构为例,总共会有64x5=320,这320个通道数的特征图,经过卷积、正则化、和激活函数之后就构成了decoder的一层了。
上述过程的形式化表述就是:
并且,有一个不得不提的是,尽管unet3+的结构相对于unet会复杂,但是在保证相同的encoder的基础上,参数量反而会少。也就是保证相同的encoder,参数量排名是:unet3+<unet<unet++
Full-scale DeepSupervision
unet3+和unet++都采用深度监督,从多尺度聚合的特征图中学习层次表示,但是监督的位置信息不同哟,这点从图中的sup
位置就能发现。unet++对第一层特征图进行监督,具体的操作是在这些特征图的后面都接c(C是类别个数)个1x1的卷积核和一个sigmoid函数,得到最后的分割结果;而unet3+则是对decoder的每一个层对应一个侧输出,这个输出经过3x3的卷积、上采样和sigmoid函数,得到分割的结果。其中上采样是为了得到全尺寸的特征图,使用相同的gt进行监督。
开发了一种混合损失用于在三级层次结构(像素级、补丁级和图片级)中进行分割,它能捕获大尺度的和精细结构清晰的边界,损失函数:是MS-SSIM loss+focal loss+ IoU loss,其中
- 多尺度结构相似性指数 (MM-SSIM) 损失用于为模糊边界分配更高的权重。
- Focal loss起源于RetinaNet,用于处理类不平衡问题。
- 使用标准IoU 损失。
Classification-guided Module
在encoder的最后一层或者说是decoder的最后一层,经过dropout、1x1卷积、最大池化和sigmoid函数得到两个值,然后经过argmax函数转换成单个输出,表示有或者没有目标,然后我们将这个输出与侧分割相乘,得到更加准确的分割结果。解决由于较浅层中背景噪声等残留信息导致过分割的现象
简答实现
import paddle
import paddle.nn as nn
# 两次卷积操作
# 卷积计算公式:
# 输出大小 = (输入大小 − Filter + 2Padding )/Stride+1
class VGGBlock(nn.Layer):
def __init__(self,in_channels,out_channels):
super(VGGBlock, self).__init__()
self.layer=nn.Sequential(
nn.Conv2D(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU(),
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU()
)
def forward(self,x):
return self.layer(x)
# 将decoder上采样若干倍数 然后经过3x3卷积得到64个通道的特征图
class Up(nn.Layer):
def __init__(self,scale_factor,in_channels):
super(Up, self).__init__()
self.layer=nn.Sequential(
nn.UpsamplingBilinear2D(scale_factor=scale_factor),
nn.Conv2D(in_channels,64,3,1,1)
)
def forward(self,x):
return self.layer(x)
# 将encoder池化若干倍数 然后经过3x3卷积得到64个通道的特征图
# 同层的话 就不需要池化操作了
class Down(nn.Layer):
def __init__(self,kernel_size,in_channels):
super(Down, self).__init__()
if kernel_size!=0:
self.layer = nn.Sequential(
nn.MaxPool2D(kernel_size),
nn.Conv2D(in_channels, 64, 3, 1, 1)
)
else:
self.layer = nn.Sequential(
nn.Conv2D(in_channels, 64, 3, 1, 1)
)
def forward(self,x):
return self.layer(x)
class UNet3Plus(nn.Layer):
def __init__(self,num_classes=2,deep_supervision=False):
super(UNet3Plus, self).__init__()
filters=[64, 128, 256, 512, 1024]
self.pool = nn.MaxPool2D(2)
self.num_classes=num_classes
self.deep_supervision=deep_supervision
## -------------encoder-------------
self.encoder_1=VGGBlock(3,filters[0])
self.encoder_2=VGGBlock(filters[0],filters[1])
self.encoder_3=VGGBlock(filters[1],filters[2])
self.encoder_4=VGGBlock(filters[2],filters[3])
self.encoder_5=VGGBlock(filters[3],filters[4])
## -------------decoder-------------
# 这是将concat得到的320个通道的特征图做处理 得到decoder当前层
self.decoder = VGGBlock(filters[0] * 5, filters[0] * 5)
# decoder4
# 因为decoder5和encoder5就是同一个 为了后面统一 就写成decoder5了
# 除了decoder5之外 剩下的decoder每层通道数都是320
self.decoder_4_encoder_1=Down(8,filters[0])
self.decoder_4_encoder_2=Down(4,filters[1])
self.decoder_4_encoder_3=Down(2,filters[2])
self.decoder_4_encoder_4=Down(0,filters[3])
self.decoder_4_decoder_5=Up(2,filters[4])
# decoder3
self.decoder_3_encoder_1 = Down(4, filters[0])
self.decoder_3_encoder_2 = Down(2, filters[1])
self.decoder_3_encoder_3 = Down(0, filters[2])
self.decoder_3_decoder_4 = Up(2, filters[0]*5)
self.decoder_3_decoder_5 = Up(4, filters[4])
# decoder2
self.decoder_2_encoder_1 = Down(2, filters[0])
self.decoder_2_encoder_2 = Down(0, filters[1])
self.decoder_2_decoder_3 = Up(2, filters[0]*5)
self.decoder_2_decoder_4 = Up(4, filters[0]*5)
self.decoder_2_decoder_5 = Up(8, filters[4])
# decoder1
self.decoder_1_encoder_1 = Down(0, filters[0])
self.decoder_1_decoder_2 = Up(2, filters[0]*5)
self.decoder_1_decoder_3 = Up(4, filters[0]*5)
self.decoder_1_decoder_4 = Up(8, filters[0]*5)
self.decoder_1_decoder_5 = Up(16, filters[4])
if self.deep_supervision:
self.final1=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
self.final2=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
self.final3=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
self.final4=nn.Conv2D(filters[0]*5,self.num_classes,3,1,1)
else:
self.final = nn.Conv2D(filters[0]*5, self.num_classes, 3, 1, 1)
def forward(self,x):
## -------------encoder-------------
encoder_1=self.encoder_1(x)
encoder_2=self.encoder_2(self.pool(encoder_1))
encoder_3=self.encoder_3(self.pool(encoder_2))
encoder_4=self.encoder_4(self.pool(encoder_3))
encoder_5=self.encoder_5(self.pool(encoder_4))
## -------------decoder-------------
# decoder4
decoder_4_encoder_1=self.decoder_4_encoder_1(encoder_1)
decoder_4_encoder_2=self.decoder_4_encoder_2(encoder_2)
decoder_4_encoder_3=self.decoder_4_encoder_3(encoder_3)
decoder_4_encoder_4=self.decoder_4_encoder_4(encoder_4)
decoder_4_decoder_5=self.decoder_4_decoder_5(encoder_5)
decoder_4=self.decoder(paddle.concat([decoder_4_decoder_5,decoder_4_encoder_4,decoder_4_encoder_3,decoder_4_encoder_2,decoder_4_encoder_1],axis=1))
# decoder3
decoder_3_encoder_1 = self.decoder_3_encoder_1(encoder_1)
decoder_3_encoder_2 = self.decoder_3_encoder_2(encoder_2)
decoder_3_encoder_3 = self.decoder_3_encoder_3(encoder_3)
decoder_3_decoder_4 = self.decoder_3_decoder_4(decoder_4)
decoder_3_decoder_5 = self.decoder_3_decoder_5(encoder_5)
decoder_3=self.decoder(paddle.concat([decoder_3_decoder_5,decoder_3_decoder_4,decoder_3_encoder_3,decoder_3_encoder_2,decoder_3_encoder_1],axis=1))
# decoder2
decoder_2_encoder_1 = self.decoder_2_encoder_1(encoder_1)
decoder_2_encoder_2 = self.decoder_2_encoder_2(encoder_2)
decoder_2_decoder_3 = self.decoder_2_decoder_3(decoder_3)
decoder_2_decoder_4 = self.decoder_2_decoder_4(decoder_4)
decoder_2_decoder_5 = self.decoder_2_decoder_5(encoder_5)
decoder_2=self.decoder(paddle.concat([decoder_2_decoder_5,decoder_2_decoder_4,decoder_2_decoder_3,decoder_2_encoder_2,decoder_2_encoder_1],axis=1))
# decoder2
decoder_1_encoder_1 = self.decoder_1_encoder_1(encoder_1)
decoder_1_decoder_2 = self.decoder_1_decoder_2(decoder_2)
decoder_1_decoder_3 = self.decoder_1_decoder_3(decoder_3)
decoder_1_decoder_4 = self.decoder_1_decoder_4(decoder_4)
decoder_1_decoder_5 = self.decoder_1_decoder_5(encoder_5)
decoder_1=self.decoder(paddle.concat([decoder_1_decoder_5,decoder_1_decoder_4,decoder_1_decoder_3,decoder_1_decoder_2,decoder_1_encoder_1],axis=1))
if self.deep_supervision:
output1 = self.final1(decoder_1)
output2 = self.final2(decoder_2)
output3 = self.final3(decoder_3)
output4 = self.final4(decoder_4)
return [output1, output2, output3, output4]
else:
output = self.final(decoder_1)
return output
if __name__ == '__main__':
# x=paddle.randn(shape=[2,3,256,256])
unet=UNet3Plus()
# print(net(x).shape)
paddle.summary(unet, (1,3,256,256))