【图像分割】TransUNet学习笔记

论文名称:TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
论文地址:https://arxiv.org/pdf/2102.04306.pdf
代码地址:https://github.com/Beckschen/TransUNet


前言:

TransUNet将Transformer和U-Net结合了起来。由于卷积操作本身存在的局限性,U-Net不能很好地建模长距离依赖关系,而Transformer这种全局自注意力机制可以有效地获取全局信息,但其对于低层次细节信息获取不充分,导致其定位能力方面受到限制。所以作者将这两者结合起来,提出了TransUNet网络。同时为了能将高分辨率的特征图通过跳级连接与上采样后的特征图联合以获得充分的信息,作者没有像ViT那样将原图直接打成patch块输入到Transformer模块中,而是先通过CNN进行特征提取得到特征图,再将其变换后输入Transformer编码器模块,最后仿照U-Net解码器逐级上采样并进行跳级连接,最后得到分割结果。


总体结构:

TransUNet总体上还是一个U型的Encoder-Decoder结构。在编码器部分,将原图输入CNN进行特征提取,线性投影之后进行Patch Embedding将特征图序列化并加上位置编码,输入transformer编码器。在解码器部分,将编码器输出的序列进行reshape然后通过1x1卷积变换通道数之后进行级联上采样,中途通过编码器CNN的各级分辨率特征图进行跳级连接,最后得到分割结果,这部分与U-Net解码器类似。

在论文中作者说道刚开始是直接应用Transformer编码器对原图进行编码,然后将输出的特征图直接上采样到原分辨率,但效果并不是最好的。作者分析输入编码器的\frac{H}{P}\times \frac{H}{P}分辨率对于原图HxW的分辨率来说还是太小了,导致损失了一些低层次的细节信息,比如边界信息。因此作者应用了一个联合CNN-Transformer的结构作为编码器,并在解码器中加入可以获得精确位置信息的级联上采样操作。

作者选用CNN-Transformer这一混合结构设计的原因有两点:1)为了将中间高分辨率的CNN特征图加入到解码器路径中以获得更多的信息以及更精确的位置。2)作者发现使用CNN-Transformer编码器要比但纯的Transformer编码器效果要好。(个人不理解的地方:TransUNet和SETR基本都是受ViT启发,对比了 CNN-Transformer Hybrid 和pure Transformer,为什么TransUNet说混合模型更好而SETR说纯Transformer模型更好呢?)
 


编码器:

ViT+ResNet50

这里的ResNet50与原ResNet50有些不同,首先卷积采用的是StdConv2d而不是传统的Conv2d,然后是用GroupNorm层代替了原来的BatchNorm层,然后我在代码中看到BottleNeck层也变成了PreActivation版本,也就是将ReLU和Normalization层前置了。在原Resnet50网络中,stage1有3个重复堆叠的Block,stage2中是4个,stage3中是6个,stage4中是3个,但在这里的ResNet50中,把stage4中的3个Block移至stage3中,所以stage3中有9个重复堆叠的Block。还有原ResNet50输出的特征图分辨率是从224x224降低到了7x7,输出为原图的1/32,而TransUNet中输出特征图分辨率为14x14,只为原图的1/16,应该是在将stage4中的3个Block拿到stage3中的时候将stride=2改成了stride=1以去掉降采样操作。

维度变化:经过Stem,分辨率变为原图1/4,[224, 224, 3]-->[56, 56, 64]。经过Stage1,分辨率不变,[56, 56, 64]-->[56, 56, 256]。经过Stage2,变为原图1/8,[56, 56, 256]-->[28, 28, 512]。经过Stage3,变为原图1/16,[28, 28, 512]-->[14, 14, 1024]。然后通过一个1x1的卷积缩减维度后进行序列化输入Transformer,[14, 14, 1024]-->[14, 14, 768]-->[196, 768],这里的维度就是Transformer需要的序列的维度,也就是论文中由[H, W, 3]变为了[\frac{HW}{P^{2}}, P^{2}\cdot C]。N= \frac{HW}{P^{2}} = 196 即为序列的长度。

最后经过Patch Embedding和位置编码,经过Patch Embedding后,[196, 768] x [768, 768]-->[196, 768],也即[\frac{HW}{P^{2}}, P^{2}\cdot C] x [P^{2}\cdot C, D]-->[\frac{HW}{P^{2}}, D]。

这里的E\in R^{(P^{2}\cdot C)\times D}是一个线性投影,目的是将序列映射到[N, D]。

更多Transformer编码器的细节后续再记录。


解码器:

编码器输出的特征图为Z_{L}\in R^{\frac{HW}{P^{2}}\times D},将其进行Reshape,然后应用1x1卷积进行维度缩减。[196, 768]-->[14, 14, 768]-->[14, 14, 512],也即论文图中的[N, D]-->[\frac{H}{P}, \frac{H}{P},D]-->[\frac{H}{P}\frac{H}{P}, 512]。后续操作和U-Net几乎相同。

,  


整体结构和代码对应图,感谢原作者!


由于本人水平非常有限,如有错误,恳请指正,欢迎大家一起交流学习!


参考:

(基础)CNN网络结构_Chan_Zeng的博客-优快云博客

Vision Transformer详解_霹雳吧啦Wz-优快云博客_wz框架

TransUnet: 结构解析_ripple970227的博客-优快云博客_unet结构详解

TransUNet_zjiafbaodaozmj的博客-优快云博客_transunet

### TransUNet 深度学习模型架构 TransUNet 是一种用于医学图像分割的强大编码器结构,该模型融合了卷积神经网络(CNNs)和变换器(Transformers),旨在提高图像分割任务的效果[^4]。 #### 架构特点 1. **混合架构设计** - TransUNet 结合了 CNN 和 Transformer 的优势。具体来说,在编码阶段采用基于 ResNet 或者 Swin Transformer 的骨干网来提取局部特征;而在解码部分则利用跳跃连接机制恢复空间分辨率并整合多尺度上下文信息。 2. **引入 DA-Blocks** - 创新性地提出了 DA-Blocks 来加强位置感知能力和通道间依赖关系的学习效果。这些模块被嵌入到跳跃连接路径中,从而增强了整个框架对于复杂场景的理解力[^3]。 3. **高效训练策略** - 训练过程中采用了精心设计的数据增强技术和损失函数优化方案,使得模型能够更有效地捕捉目标区域边界,并减少过拟合现象的发生。此外,合理的超参数调整也促进了更快收敛速度以及更好的泛化表现[^1]。 ```python import torch.nn as nn class TransUNet(nn.Module): def __init__(self, img_size=224, in_channels=3, out_channels=1, ... ): super().__init__() self.encoder = Encoder(img_size=img_size, in_chans=in_channels,... ) # 使用ResNet/Swin作为backbone self.decoder = Decoder(num_classes=out_channels,... ) def forward(self,x): features = self.encoder(x) output = self.decoder(features) return output ``` ### 应用领域 TransUNet 主要应用于医疗影像分析方面,特别是在处理高维、细粒度的任务上表现出色。例如: - 对于 CT 扫描得到的心脏血管造影图片进行精确切割; - 针对 MRI 影像实现脑部肿瘤自动检测与轮廓描绘等功能。 通过对不同三维模型之间推理时间、训练耗时、浮点运算次数(FLOPs) 及显存占用情况对比可以发现,尽管 TransUNet 在某些指标上可能不是最优的选择,但在综合考量精度与时效性的前提下仍然具有明显的优势[^2]。
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值