(即插即用模块-特征处理部分) 十八、(TIM 2022) TIF Transformer交互融合模块

在这里插入图片描述

paper:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation

Code:https://github.com/TianBaoGe/DS-TransUNet


1、Transformer Interactive Fusion

对于一些传统的特征融合方法(如简单拼接),这些会导致无法有效捕捉不同尺度特征之间的长程依赖和全局上下文信息,最终会导致分割性能受限。而现有基于Transformer的分割模型主要集中在编码器部分,而解码器仍然使用CNN,无法充分利用Transformer的优势。直接拼接多尺度特征会导致特征错位和语义差距,影响分割精度。为此,这篇论文提出一种 Transformer交互融合模块(Transformer Interactive Fusion module)

TIF 模块通过利用 Transformer 的自注意力机制,在不同尺度特征之间建立长程依赖关系,并有效融合多尺度上下文信息。引入标准 Transformer块 而非 Swin Transformer块,这能使 TIF 模块能够灵活地处理不同分支的特征图,并进行有效的特征融合。

对于两个分支的特征图分别为 F 和 G,TIF 的实现过程:

  1. 全局抽象: 将 G 特征图进行全局平均池化,并将其投影到一个较低维度的向量中,作为 G 的全局抽象信息。这个向量代表 G 特征图的整体信息,可以与 F 特征图中的每个像素进行交互。
  2. 序列拼接: 将 F 特征图与 G 的全局抽象信息进行拼接,形成一个包含多个token的序列。每个token代表 F 特征图中的一个像素,并与 G 的全局抽象信息相连。
  3. 自注意力计算: 将拼接后的序列送入标准Transformer块,进行自注意力计算。在自注意力机制中,每个token会根据其与其他token之间的相似度进行加权,从而获得一个更全面的特征表示。这样,F 特征图中的每个像素都能够获得来自 G 特征图的全局信息,从而提高了分割精度。
  4. 特征图恢复: 将Transformer块的输出进行切片,得到与 F 特征图分辨率相同的融合特征图。

Transformer Interactive Fusion 结构图:
在这里插入图片描述


2、代码实现

import torch
from torch import nn, einsum
from einops.einops import rearrange


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
    

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
            
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class TIF(nn.Module):
    def __init__(self, dim_s, dim_l):
        super().__init__()
        self.transformer_s = Transformer(dim=dim_s, depth=1, heads=3, dim_head=32, mlp_dim=128)
        self.transformer_l = Transformer(dim=dim_l, depth=1, heads=1, dim_head=64, mlp_dim=256)
        self.norm_s = nn.LayerNorm(dim_s)
        self.norm_l = nn.LayerNorm(dim_l)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.linear_s = nn.Linear(dim_s, dim_l)
        self.linear_l = nn.Linear(dim_l, dim_s)

    def forward(self, e, r):
       b_e, c_e, h_e, w_e = e.shape
       e = e.reshape(b_e, c_e, -1).permute(0, 2, 1)
       b_r, c_r, h_r, w_r = r.shape
       r = r.reshape(b_r, c_r, -1).permute(0, 2, 1)
       e_t = torch.flatten(self.avgpool(self.norm_l(e).transpose(1,2)), 1)
       r_t = torch.flatten(self.avgpool(self.norm_s(r).transpose(1,2)), 1)
       e_t = self.linear_l(e_t).unsqueeze(1)
       r_t = self.linear_s(r_t).unsqueeze(1)
       r = self.transformer_s(torch.cat([e_t, r],dim=1))[:, 1:, :]
       e = self.transformer_l(torch.cat([r_t, e],dim=1))[:, 1:, :]
       e = e.permute(0, 2, 1).reshape(b_e, c_e, h_e, w_e)
       r = r.permute(0, 2, 1).reshape(b_r, c_r, h_r, w_r)
       return e + r


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7).cuda()
    y = torch.randn(4, 512, 7, 7).cuda()
    model = TIF(512, 512).cuda()
    out = model(x,y)
    print(out.shape)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值