【论文阅读】U-Net论文详解

本文详细介绍了U-Net网络结构,该网络在图像语义分割任务中表现出色,尤其适合从少量图像进行端到端训练。U-Net结合了压缩路径(捕获上下文信息)和扩展路径(支持精确定位),解决了滑动窗口卷积网络的速度和定位精度问题。网络由全卷积层构成,通过编码器和解码器实现特征的下采样和上采样,同时利用跳跃连接保持细节信息。Pytorch实现中,使用了Encoder和Decoder模块,并在实验中采用了数据增强来提高模型的鲁棒性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

U-Net论文详解

UNet算法Pytorch实现:https://github.com/codecat0/CV/tree/main/Semantic_Segmentation/UNet

U-Net结构由一个用于捕获上下文信息的压缩路径和一个支持精确定位的对称扩展路径构成。实验结果表明可以从很少的图像进行端到端的训练,并在ISBI挑战上优于先前最优的方法(滑动窗口卷积网络),并获得了冠军

1. 背景介绍

卷积网络的典型应用是分类任务,其中图像的输出是一个单一的类标签。然而在许多视觉任务中,特别是生物医学图像处理中,期望的输出应该包含定位,即给每一个像素点分配一个类标签。

于是滑动窗口卷积网络通过提供像素点周围的局部区域来预测每个像素的类别标签。但是这样的方法存在两个缺点:

  1. 速度特别慢,网络必须为每一个窗口单元单独运行,并且窗口单元重合而导致大量冗余
  2. 在定位精度和上下文信息之间的权衡。大的窗口单元需要更多的max pooling层,这会降低精度;而小的窗口单元捕获的上下文信息较少。

于是本文提出了U-Net网络

2. U-Net网络架构

在这里插入图片描述

网络是一个经典的全卷积网络。网络的输入是一张572x572经过镜像操作的图像。为了使得每次下采样后特征图的尺寸为偶数。
在这里插入图片描述

网络的左侧为压缩路径,由4个block构成,每个block由2个未padding的卷积和一个最大池化构成,其中每次卷积特征图的尺寸为减小2,最大池化后会缩小一半。

现在大部分采用same padding的卷积,这样就不用对输入进行镜像操作,而且在拼接压缩路径与对应的扩展路径也不用进行裁剪,而且裁剪会使得特征图不对称

网络的右侧为扩展路径,同样由4个block构成,每个block开始之前通过反卷积将特征图的尺寸扩大一倍,然后与压缩路径对应的特征图拼接,由于采用未padding的卷积,左侧压缩路径的特征图的尺寸比右侧扩展路径的特征图的大,所以需要先进行裁剪,使其大小相同,然后拼接,然后经过两次未padding的卷积进一步提取特征

最后根据自己的任务,输出对应大小的预测特征图

现在大部分采用双线性插值代替反卷积,而且效果会更好

3. 数据增强

我们主要通过平移和旋转不变性以及灰度值的变化来增强模型的鲁棒性,特别地,任意的弹性形变对训练非常有帮助

4. Pytorch实现

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x_pooled = self.pool(x)
        return x, x_pooled


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.up_sample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x_prev, x):
        x = self.up_sample(x)
        x_shape = x.shape[2:]
        x_prev_shape = x.shape[2:]
        h_diff = x_prev_shape[0] - x_shape[0]
        w_diff = x_prev_shape[1] - x_shape[1]
        # padding
        x_tmp = torch.zeros(x_prev.shape).to(x.device)
        x_tmp[:, :, h_diff//2: h_diff+x_shape[0], w_diff//2: x_shape[1]] = x
        x = torch.cat([x_prev, x_tmp], dim=1)
        x = self.block1(x)
        x = self.block2(x)
        return x



class UNet(nn.Module):
    # https://arxiv.org/abs/1505.04597
    def __init__(self, num_classes=2):
        super(UNet, self).__init__()

        self.down_sample1 = Encoder(in_channels=3, out_channels=64)
        self.down_sample2 = Encoder(in_channels=64, out_channels=128)
        self.down_sample3 = Encoder(in_channels=128, out_channels=256)
        self.down_sample4 = Encoder(in_channels=256, out_channels=512)

        self.mid1 = nn.Sequential(
            nn.Conv2d(512, 1024, 3, bias=False),
            nn.ReLU(inplace=True)
        )
        self.mid2 = nn.Sequential(
            nn.Conv2d(1024, 1024, 3, bias=False),
            nn.ReLU(inplace=True)
        )

        self.up_sample1 = Decoder(in_channels=1024, out_channels=512)
        self.up_sample2 = Decoder(in_channels=512, out_channels=256)
        self.up_sample3 = Decoder(in_channels=256, out_channels=128)
        self.up_sample4 = Decoder(in_channels=128, out_channels=64)

        self.classifier = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        x1, x = self.down_sample1(x)
        x2, x = self.down_sample2(x)
        x3, x = self.down_sample3(x)
        x4, x = self.down_sample4(x)

        x = self.mid1(x)
        x = self.mid2(x)

        x = self.up_sample1(x4, x)
        x = self.up_sample2(x3, x)
        x = self.up_sample3(x2, x)
        x = self.up_sample4(x1, x)

        x = self.classifier(x)
        return x



if __name__ == '__main__':
    input = torch.rand(1, 3, 384, 384)
    model = UNet(2)
    out = model(input)
    print(out.shape)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值