模块出处
[CVPR 22] [link] [code] Pyramid Grafting Network for One-Stage High Resolution Saliency Detection
模块名称
Cross-Model Grafting Module (CMGM)
模块作用
Transformer与CNN之间的特征融合
模块结构
模块思想
Transformer在全局特征上更优,CNN在局部特征上更优,对这两者进行进行融合的最简单做法是直接相加或相乘。但是,相加或相乘本质上属于"局部"操作,如果某片区域两个特征的不确定性都较高,则会带来许多噪声。为此,本文提出了CMGM模块,通过交叉注意力的形式引入更为广泛的信息来增强融合效果。
模块代码
import torch.nn.functional as F
import torch.nn as nn
import torch
class CMGM(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.k = nn.Linear(dim, dim , bias=qkv_bias)
self.qv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.proj = nn.Linear