第一步:准备数据
X射线图像牙齿分割,总共有2000张
第二步:搭建模型
UNet3+主要是参考了UNet和UNet++两个网络结构。尽管UNet++采用了嵌套和密集跳过连接的网络结构(见图1(b)红色三角区域),但是它没有直接从多尺度信息中提取足够多的信息。此部分,在我理解而言UNet++虽然名义上通过嵌套和密集跳过连接进行了多尺度信息的利用,但是从本质上看基本都是短连接,基本上都对解码特征进行了再次处理,再加上各个连接的融合,多尺度信息的原始特征几乎没有得到特别好的利用,信号处理有些矫枉过正或是丢失。UNet3+利用了全尺度的跳跃连接(skip connection)和深度监督(deep supervisions)。全尺度的跳跃连接把来自不同尺度特征图中的高级语义与低级语义直接结合(当然需要必要的上采样操作);而深度监督则从多尺度聚合的特征图中学习层次表示。注意一点:UNet++和UNet3+都用到了深度监督,但是监督的位置是完全不一样的,从图1(b)、(c)中的Sup部分可以清楚的看到不同之处。
在UNet3+中,可以从全尺度捕获细粒度的细节和粗粒度的语义。为了进一步从全尺寸的聚合特征图中学习层次表示法,每个边的输出都与一个混合损失函数相连接,这有助于精确分割,特别是对于在医学图像体积中出现不同尺度的器官。从图1中也可以看出,UNet3+的参数量明显小于UNet++。
第三步:代码
1)损失函数为:交叉熵损失函数
2)网络代码:
class UNet3Plus(nn.Module):
def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4,
is_deconv=True, is_batchnorm=True):
super(UNet3Plus, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.feature_scale = feature_scale
self.is_deconv = is_deconv
self.is_batchnorm = is_batchnorm
filters = [16, 32, 64, 128, 256]
## -------------Encoder--------------
self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
## -------------Decoder--------------
self.CatChannels = filters[0]
self.CatBlocks = 5
self.UpChannels = self.CatChannels * self.CatBlocks
'''stage 4d'''
# h1->320*320, hd4->40*40, Pooling 8 times
self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
# h2->160*160, hd4->40*40, Pooling 4 times
self.h2_PT_h