语义分割学习总结(二)—— Unet网络

目录

一、网络结构

(一)左半部分(特征提取部分)

(二)右半部分(特征融合部分)

(三)代码实现

(二)重叠平铺策略

(三)加权损失

(四)随机弹性形变


一、网络结构

(图源来自网络)

这个结构的思想其实就是先对图像进行卷积+池化,进行特征提取,也就是U型的左半部分,然后对图像拼接+上采样,进行特征融合。 

(一)左半部分(特征提取部分)

两个3x3的卷积层(ReLU)+ 一个2x2的maxpooling层构成一个下采样的模块,由下采样模块反复组成。每经过一次下采样,通道数翻倍。论文中用的是valid卷积(当filter全部在image里面的时候才开始进行卷积运算),因此每做一次valid卷积,由于没有padding,feature map的height和width会分别减少3-1=2个像素。

(二)右半部分(特征融合部分)

一个2x2的上采样卷积层(ReLU)+Concatenation(先crop对应左半部分输出的feature map然后与右半部分上采样结果相加)+2个3x3的卷积层(ReLU)反复构成,最后一层通过一个1x1卷积将通道数变成期望的类别数(论文中的channel2分别为前景和背景的mask,医学中就是细胞区域和黑色背景区域)。每一次上采样转置卷积之后,height和width都加倍,同时channel减半,用于和左侧的浅层feature map进行合并拼接。Unet相比更早提出的FCN网络,使用通道拼接来作为特征图的融合方式。主要好处是,浅层卷积关注纹理特征,深层网络关注更深更本质的特征,将浅层网络提取的特征和深层网络提取的特征融合可以使得特征“厚且广”,还有一个原因我认为是下采样操作会导致高频信息丢失,从而导致边缘的特征丢失,而上采样虽然能够获得更大的特征图,但是并不能对进行过下采样的特征图进行恢复,因此是缺少信息的,通过这种特征拼接多少可以找回一些丢失的边缘信息。

(三)代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv2d_bn(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=3,strides=1,padding=1):
        super(double_conv2d_bn,self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,
                                stride = strides,padding = padding ,bias =True)
        self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size = kernel_size,
                                stride = strides,padding = padding, bias = True)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out

class deconv2d_bn(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=2,strides=2):
        super(deconv2d_bn,self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels,out_channels,
                                        kernel_size= kernel_size,
                                        stride = strides,bias = True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out

class Unet(nn.Module):
    def __init__(self):
        super(Unet,self).__init__()
        self.layer1_conv = double_conv2d_bn(1,8)
        self.layer2_conv = double_conv2d_bn(8,16)
        self.layer3_conv = double_conv2d_bn(16,32)
        self.layer4_conv = double_conv2d_bn(32,64)
        self.layer5_conv = double_conv2d_bn(64,128)
        self.layer6_conv = double_conv2d_bn(128,64)
        self.layer7_conv = double_conv2d_bn(64,32)
        self.layer8_conv = double_conv2d_bn(32,16)
        self.layer9_conv = double_conv2d_bn(16,8)
        self.layer10_conv = nn.Conv2d(8,1,kernel_size=3,
                                        stride = 1,padding =1,bias = True)
        
        self.deconv1 = deconv2d_bn(128,64)
        self.deconv2 = deconv2d_bn(64,32)
        self.deconv3 = deconv2d_bn(32,16)
        self.deconv4 = deconv2d_bn(16,8)

        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        #print(x.shape) [10, 1, 224, 224]
        conv1 = self.layer1_conv(x)
        print(conv1.shape)
        pool1 = F.max_pool2d(conv1,2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2,2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3,2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4,2)

        conv5 = self.layer5_conv(pool4)
        #print(conv5.shape) ([10, 128, 14, 14])
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1,conv4],dim=1)
        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2,conv3],dim=1)
        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3,conv2],dim=1)
        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4,conv1],dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)

        return outp

model = Unet()
inp = torch.rand(10,1,224,224)
outp = model(inp)

(二)重叠平铺策略

该策略的思想是:对图像的某一块像素点(黄框内)进行预测时,需要该图像块周围的像素点(细蓝色框内)提供上下文信息,以获得更准确的预测。但是图像边界的图像块(黄色框内)没有周围像素,因此作者对周围像素采用了镜像扩充(比如粗蓝色框内的图像可以看出来是经过白线镜像对称得到的)。这样,边界图像块也能得到准确的预测。

(三)加权损失

该策略的思想是:为了更好的实现边界分割,对于细胞边缘的分割像素点加大了损失权重,使得网络更加重视边缘像素的学习。损失函数采用交叉熵损失函数的加权:

其中权重:

 softmax:

(四)随机弹性形变

该策略的思想是:实际上是一种数据增强方式。数据增强能够有效的提高网络的不变性和稳定性,由于医学数据通常是非常少的,因此数据增强就变得异常重要。而对于显微图像主要需要位移和旋转不变性以及对变形和灰度值变化的鲁棒性,于是文中使用了随机弹性形变。

———————————————————————————————————————————

【参考】:

如何理解u_net中的overlap_tile策略? - 知乎

图像分割必备知识点 | Unet详解 理论+ 代码 - 知乎

数据增强:弹性变形(Elastic Distortion) - 知乎

以上就是全部啦,如有疑问和错误欢迎私信~

<think>好的,用户想了解超分辨率中的浅层特征提取方法、技术和实现。首先,我需要回忆一下超分辨率的基本概念和相关技术。超分辨率(Super-Resolution, SR)是指从低分辨率图像重建高分辨率图像的技术,而浅层特征提取通常是网络中的前几层,负责捕捉低级特征,比如边缘、纹理等。 根据用户提供的引用,特别是引用[2]提到的SPIN方法,它使用超像素来分组局部像素,可能和特征提取有关联。但浅层特征一般更基础,可能涉及卷积层、残差块等。引用[4]中的RCAN结构提到了残差缩放块,这可能属于深层网络部分,但浅层可能使用简单的卷积层。 接下来,我需要整理常见的浅层特征提取方法。比如,经典的SRCNN开始部分用两个卷积层提取特征。后来的VDSR、EDSR等模型也采用类似结构,但增加了深度和残差连接。浅层通常包括初始卷积,可能使用较大的核尺寸来捕捉更广的上下文信息,比如7x7或5x5的卷积核,之后接ReLU激活。 另外,引用[3]提到FPN通过横向连接结合低级和高级特征,这可能影响特征提取的设计,但FPN更多是特征金字塔,属于结构设计的一部分。不过,在超分任务中,浅层特征的提取可能不需要复杂的金字塔结构,而是专注于基础特征的获取。 技术实现方面,浅层特征通常通过卷积层实现,可能包括多通道卷积,比如将输入图像映射到高维特征空间。例如,EDSR中使用两个3x3卷积提取浅层特征,然后传入深层网络。此外,一些方法可能使用预处理步骤,比如插值上采样后再处理,但现代方法如ESRGAN倾向于在低分辨率空间提取特征,最后上采样。 需要注意用户提到的引用中是否有其他相关技术。引用[2]中的超像素内注意力可能属于更高级的特征处理,可能在浅层之后。而浅层特征提取更注重基础,所以可能需要结合传统卷积操作和简单的激活函数。 最后,生成相关问题需要考虑用户可能的后续疑问,比如浅层与深层特征的区别、如何优化浅层特征提取、或特定网络结构中的实现差异等。确保问题与主题相关,并引导进一步学习。</think>### 超分辨率中的浅层特征提取方法和技术 #### 1. **浅层特征的定义与作用** 在超分辨率任务中,**浅层特征**通常指网络前几层提取的低级视觉特征,例如边缘、纹理、颜色分布等。这些特征是图像重建的基础,直接影响后续深层网络对高频细节的恢复能力[^3]。 #### 2. **常用方法与技术** ##### (1) **单/多卷积层提取** - **基础实现**:通过单个卷积层直接映射低分辨率图像到特征空间。例如,使用一个$3 \times 3$卷积核生成64通道特征: ```python self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1) ``` - **多级增强**:部分模型(如EDSR)采用多个卷积堆叠,增强浅层表达能力: ```python self.conv1 = nn.Conv2d(3, 64, 3, padding=1) self.conv2 = nn.Conv2d(64, 64, 3, padding=1) ``` ##### (2) **残差连接** 在浅层引入残差结构(如RCAN的残差缩放块[^4]),避免梯度消失并加速训练: ```python class ShallowResBlock(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(64, 64, 3, padding=1) self.conv2 = nn.Conv2d(64, 64, 3, padding=1) def forward(self, x): return x + self.conv2(nn.ReLU()(self.conv1(x))) ``` ##### (3) **注意力机制** 结合通道注意力(如ECA-Net)或空间注意力,聚焦重要特征区域: ```python class ChannelAttention(nn.Module): def __init__(self, channels): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channels, channels//4), nn.ReLU(), nn.Linear(channels//4, channels), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x) ``` ##### (4) **超像素引导** 如SPIN方法[^2]通过**超像素内注意力**划分局部区域,增强浅层特征的语义一致性: - 超像素将图像分割为局部相似区域 - 注意力机制在超像素内聚合特征 #### 3. **实现示例(PyTorch)** ```python class ShallowFeatureExtractor(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, 5, padding=2) # 大核捕捉广域上下文 self.res_blocks = nn.Sequential( ShallowResBlock(), ShallowResBlock() ) self.attention = ChannelAttention(64) def forward(self, lr_img): x = nn.ReLU()(self.conv1(lr_img)) x = self.res_blocks(x) return self.attention(x) ``` #### 4. **关键设计原则** 1. **平衡感受野**:浅层卷积核尺寸通常选择$3 \times 3$或$5 \times 5$,兼顾局部细节与上下文 2. **轻量化设计**:避免过多参数导致过拟合(如RCAN的残差缩放块) 3. **特征保留**:减少下采样操作,保持空间分辨率 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值