UNETR模型讲解

UNETR模型详解

论文的arXiv 链接:https://arxiv.org/abs/2103.10504

简介

        UNETR(UNEt TRansformer)是一种基于Transformer架构的医学图像分割模型,结合了U-Net的编码器-解码器结构和Transformer的全局建模能力。它通过替换U-Net的卷积编码器为纯Transformer模块,直接从3D图像块序列中提取多尺度特征,同时保留U-Net的解码器进行渐进式上采样和定位。网络整体架构如下:

🟢 绿色立方体Deconv 2×2×2:3D 转置卷积,上采样 ×2
🔵 蓝色立方体Deconv 2×2×2 + Conv 3×3×3 + BN + ReLU:完整上采样+融合单元
🟡 黄色立方体Conv 3×3×3 + BN + ReLU:普通卷积层
🔵+🟢 组合表示一个“上采样 + 卷积融合”模块
🔵(带 c)拼接(Concatenation)操作
🟤 灰色立方体Conv 1×1×1:最终分类头

        UNETR 保留了 U-Net 的经典 对称 U 型结构:

(1)编码器(下采样路径):逐级提取高层语义特征;
(2)解码器(上采样路径):解码器仍是 CNN,非 Transformer,逐步恢复空间分辨率;
(3)跳跃连接:将编码器各阶段的多尺度特征传递给对应层级的解码器,用于融合局部细节与全局语义。

流程

        下面我以一个例子简单讲解一下整个的计算流程以及尺寸变化。我们假设:

        输入图像尺寸:96 × 96 × 96(H × W × D,体素空间)
        输入通道数:1(如单模态 CT)
        目标分割类别数:14(如 MSD 数据集中的多器官分割)

Patch Embedding(分块嵌入)

        将输入划分为 non-overlapping 3D patches。每个 patch 大小:16×16×16;每个 patch 体素数:16 × 16 × 16 = 4096;总 patch 数:(96//16) × (96//16) × (96//16) = 6 × 6 × 6 = 216。

        每个 patch 展平为 16×16×16 = 4096 维向量,通过线性投影(embedding)映射到隐空间维度 C = 768(标准 ViT 设置)。
输出张量:
        Token 序列 [B, N=216, C=768],加上位置编码(可学习)后输入 Transformer。(B是批次大小batch_size)

Transformer 编码器(多层 ViT)

        12 层标准 Transformer blocks。

        关键输出层:取第 3、6、9、12 层的输出用于跳跃连接(对应 4 个尺度)

层级输出 token 尺寸对应“虚拟”空间分辨率
ViT-3[B, 216, 768]6×6×6
ViT-6[B, 216, 768]6×6×6
ViT-9[B, 216, 768]6×6×6
ViT-12[B, 216, 768]6×6×6

多尺度特征重建(从 token 到 3D 特征图)重点!

        这一步是UNETR 的关键技巧,将固定数量的 token 重建为不同分辨率的 3D 特征图,以便与解码器的上采样路径对接。

        使用 3D 卷积转置(ConvTranspose3d)或插值 + 卷积 逐步上采样。采用分阶段上采样模块,从 ViT-12 开始,逐级融合并上采样:

阶段来源重建后特征图尺寸通道数
S4(最深层)ViT-126 × 6 × 6 (96/16)768 → 32
S3ViT-9 + 上采样(S4)12 × 12 × 12(96/8)32
S2ViT-6 + 上采样(S3)24 × 24 × 24(96/4)32
S1(最浅层)ViT-3 + 上采样(S2)48 × 48 × 48(96/2)32

CNN 解码器(上采样 + 跳跃连接融合)

阶段来源(ViT 层)解码器输入操作步骤(按顺序)输出尺寸(B=1)
D4
(最深层)
ViT-12(Z₁₂)[1, 32, 6, 6, 6]1. 上采样 ×2(Deconv 2×2×2)
2. 拼接 S3(来自 ViT-9 上采样至 12³)
3. 3×3×3 Conv + BN + ReLU(融合)
[1, 32, 12, 12, 12]
D3ViT-9(Z₉)→ S3[1, 32, 12, 12, 12]1. 上采样 ×2
2. 拼接 S2(ViT-6 上采样至 24³)
3. 3×3×3 Conv 融合
[1, 32, 24, 24, 24]
D2ViT-6(Z₆)→ S2[1, 32, 24, 24, 24]1. 上采样 ×2
2. 拼接 S1(ViT-3 上采样至 48³)
3. 3×3×3 Conv 融合
[1, 32, 48, 48, 48]
D1
(输出层)
ViT-3(Z₃)→ S1[1, 32, 48, 48, 48]1. 上采样 ×2 → 96³
2. 1×1×1 Conv(分类头)
[1, C=类别数, 96, 96, 96]

        S1、S2、S3、S4是 ViT 中间层输出经过上采样和通道调整后的特征图,用于跳跃连接:

S1 ← ViT-3 → 上采样 ×8 → [48,48,48]

S2 ← ViT-6 → 上采样 ×4 → [24,24,24]

S3 ← ViT-9 → 上采样 ×2 → [12,12,12]

S4 ← ViT-12 → 无上采样 → [6,6,6](作为 D4 的初始输入)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值