第一步:准备数据
海龟实例分割数据,总共有8729张图片,里面的像素值为0、1、2和3,所以看起来全部是黑的,不影响使用
第二步:搭建模型
DeepLabV3+的网络结构如下图所示,主要为Encoder-Decoder结构。其中,Encoder为改进的DeepLabV3,Decoder为3+版本新提出的。
1.1、Encoder
在Encoder部分,主要包括了backbone(即:图1中的DCNN)、ASPP两大部分。
其中backbone有两种网络结构:将layer4改为空洞卷积的Resnet系列、改进的Xception。从backbone出来的feature map分两部分:一部分是最后一层卷积输出的feature maps,另一部分是中间的低级特征的feature maps;backbone输出的第一部分送入ASPP模块,第二部分则送入Decoder模块。
ASPP模块接受backbone的第一部分输出作为输入,使用了四种不同膨胀率的空洞卷积块(包括卷积、BN、激活层)和一个全局平均池化块(包括池化、卷积、BN、激活层)得到一共五组feature maps,将其concat起来之后,经过一个1*1卷积块(包括卷积、BN、激活、dropout层),最后送入Decoder模块。
1.2、Decoder
在Decoder部分,接收来自backbone中间层的低级feature maps和来自ASPP模块的输出作为输入。
首先,对低级feature maps使用1*1卷积进行通道降维,从256降到48(之所以需要降采样到48,是因为太多的通道会掩盖ASPP输出的feature maps的重要性,且实验验证48最佳);
然后,对来自ASPP的feature maps进行插值上采样,得到与低级featuremaps尺寸相同的feature maps;
接着,将通道降维的低级feature maps和线性插值上采样得到的feature maps使用concat拼接起来,并送入一组3*3卷积块进行处理;
最后,再次进行线性插值上采样,得到与原图分辨率大小一样的预测图。
第三步:代码
1)损失函数为:交叉熵损失函数
2)网络代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2
class MobileNetV2(nn.Module):
def __init__(self, downsample_factor=8, pretrained=True):
super(MobileNetV2, self).__init__()
from functools import partial