Paper: DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
1. Transformer Interactive Fusion
Traditional feature-fusion methods such as simple concatenation fail to capture the long-range dependencies and global context between features of different scales, which ultimately limits segmentation performance. Moreover, existing Transformer-based segmentation models concentrate the Transformer in the encoder while the decoder remains a CNN, so the strengths of the Transformer are not fully exploited. Directly concatenating multi-scale features also introduces feature misalignment and a semantic gap, hurting segmentation accuracy. To address this, the paper proposes a Transformer Interactive Fusion (TIF) module.
By leveraging the Transformer's self-attention mechanism, the TIF module builds long-range dependencies between features of different scales and effectively fuses multi-scale contextual information. It uses standard Transformer blocks rather than Swin Transformer blocks, which allows TIF to handle the feature maps of the different branches flexibly and fuse them effectively.
For the two branch feature maps F and G, TIF proceeds as follows (a minimal sketch of these steps is given below the list):
- Global abstraction: apply global average pooling to G and project the result into a lower-dimensional vector that serves as the global abstraction of G. This vector summarizes the entire G feature map and can interact with every pixel of F.
- Sequence concatenation: concatenate the flattened F feature map with the global abstraction of G to form a token sequence; each token represents one pixel of F, joined by the single global token from G.
- Self-attention: feed the concatenated sequence into a standard Transformer block. In self-attention, each token is weighted by its similarity to every other token, yielding a more comprehensive feature representation; every pixel of F thus receives global information from G, which improves segmentation accuracy.
- Feature-map recovery: slice the Transformer block's output (dropping the global token) to obtain a fused feature map with the same resolution as F.
Figure: structure of the Transformer Interactive Fusion module.
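Before the full implementation in the next section, here is a minimal single-direction sketch of the four steps above (fusing G's global token into F only). It uses PyTorch's built-in nn.TransformerEncoderLayer in place of the paper's pre-norm block, and all shapes, channel counts and layer choices are illustrative assumptions rather than the paper's configuration:

import torch
from torch import nn

# Illustrative shapes only: F is one branch's feature map, G the other branch's.
B, C, H, W = 2, 96, 14, 14
C_g = 192
F = torch.randn(B, C, H, W)
G = torch.randn(B, C_g, 28, 28)

# Step 1: global abstraction of G -> a single token in F's channel dimension.
g_vec = G.mean(dim=(2, 3))                      # (B, C_g) global average pooling
g_tok = nn.Linear(C_g, C)(g_vec).unsqueeze(1)   # (B, 1, C) projection

# Step 2: flatten F into pixel tokens and prepend the global token.
f_tok = F.flatten(2).transpose(1, 2)            # (B, H*W, C)
seq = torch.cat([g_tok, f_tok], dim=1)          # (B, 1 + H*W, C)

# Step 3: self-attention over the whole sequence with a standard Transformer block.
block = nn.TransformerEncoderLayer(d_model=C, nhead=4, batch_first=True)
seq = block(seq)

# Step 4: drop the global token and restore the (B, C, H, W) layout.
F_fused = seq[:, 1:, :].transpose(1, 2).reshape(B, C, H, W)
print(F_fused.shape)  # torch.Size([2, 96, 14, 14])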
2. Code Implementation
import torch
from torch import nn, einsum
from einops import rearrange
class PreNorm(nn.Module):
    """Applies LayerNorm before the wrapped function (pre-norm Transformer)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation, used after attention in each block."""
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Standard multi-head self-attention."""
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        # Project to queries, keys and values and split heads: (b, n, h*d) -> (b, h, n, d)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        # Scaled dot-product attention over the token dimension
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        # Weighted sum of values, then merge the heads back: (b, h, n, d) -> (b, n, h*d)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Transformer(nn.Module):
    """A stack of pre-norm Transformer blocks (attention + MLP, each with a residual)."""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
class TIF(nn.Module):
    """Transformer Interactive Fusion: fuses the two branch feature maps via exchanged global tokens."""
    def __init__(self, dim_s, dim_l):
        super().__init__()
        self.transformer_s = Transformer(dim=dim_s, depth=1, heads=3, dim_head=32, mlp_dim=128)
        self.transformer_l = Transformer(dim=dim_l, depth=1, heads=1, dim_head=64, mlp_dim=256)
        self.norm_s = nn.LayerNorm(dim_s)
        self.norm_l = nn.LayerNorm(dim_l)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.linear_s = nn.Linear(dim_s, dim_l)
        self.linear_l = nn.Linear(dim_l, dim_s)

    def forward(self, e, r):
        # Flatten both feature maps into token sequences: (B, C, H, W) -> (B, H*W, C)
        b_e, c_e, h_e, w_e = e.shape
        e = e.reshape(b_e, c_e, -1).permute(0, 2, 1)
        b_r, c_r, h_r, w_r = r.shape
        r = r.reshape(b_r, c_r, -1).permute(0, 2, 1)

        # Global abstraction: average-pool each branch into a single token
        # and project it to the other branch's channel dimension
        e_t = torch.flatten(self.avgpool(self.norm_l(e).transpose(1, 2)), 1)
        r_t = torch.flatten(self.avgpool(self.norm_s(r).transpose(1, 2)), 1)
        e_t = self.linear_l(e_t).unsqueeze(1)  # (B, 1, dim_s)
        r_t = self.linear_s(r_t).unsqueeze(1)  # (B, 1, dim_l)

        # Prepend the other branch's global token, run self-attention,
        # then drop the global token again ([:, 1:, :])
        r = self.transformer_s(torch.cat([e_t, r], dim=1))[:, 1:, :]
        e = self.transformer_l(torch.cat([r_t, e], dim=1))[:, 1:, :]

        # Restore the (B, C, H, W) layout and fuse the two branches
        e = e.permute(0, 2, 1).reshape(b_e, c_e, h_e, w_e)
        r = r.permute(0, 2, 1).reshape(b_r, c_r, h_r, w_r)
        return e + r
if __name__ == '__main__':
    # Sanity check with two same-shaped branch feature maps (requires a CUDA device)
    x = torch.randn(4, 512, 7, 7).cuda()
    y = torch.randn(4, 512, 7, 7).cuda()
    model = TIF(512, 512).cuda()
    out = model(x, y)
    print(out.shape)  # torch.Size([4, 512, 7, 7])
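Two notes on this sanity check (my reading of the code, not something stated in the paper): the final e + r fusion requires the two feature maps passed to TIF to share the same channel count and spatial resolution, which is why the test uses two 4x512x7x7 tensors; and the .cuda() calls assume a GPU is available, so drop them to run the check on CPU.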