Module source
[ICCV 23] [link] [code] Scale-Aware Modulation Meet Transformer
Module name
Scale-Aware Modulation (SAM)
Module purpose
Improved self-attention: a convolution-based, scale-aware modulation that stands in for standard self-attention.
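To illustrate the idea before the full module code below, here is a minimal, simplified sketch (not the authors' implementation): multi-scale depthwise convolutions build a modulator that gates a linearly projected value. The class name ScaleAwareModulationSketch, the single 1x1 aggregation conv, and all hyperparameters are illustrative assumptions.

import torch
import torch.nn as nn

class ScaleAwareModulationSketch(nn.Module):
    # Hypothetical sketch of scale-aware modulation, not the official SAM.
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # One depthwise conv per head with growing kernel size (3, 5, 7, 9)
        self.convs = nn.ModuleList([
            nn.Conv2d(head_dim, head_dim, kernel_size=3 + 2 * i,
                      padding=1 + i, groups=head_dim)
            for i in range(num_heads)
        ])
        self.v = nn.Linear(dim, dim)       # value projection
        self.agg = nn.Conv2d(dim, dim, 1)  # crude stand-in for scale-aware aggregation
        self.proj = nn.Linear(dim, dim)    # output projection

    def forward(self, x, H, W):
        # x: (B, N, C) token sequence with N = H * W
        B, N, C = x.shape
        v = self.v(x)
        s = x.transpose(1, 2).reshape(B, C, H, W)
        heads = s.chunk(self.num_heads, dim=1)
        s = torch.cat([conv(h) for conv, h in zip(self.convs, heads)], dim=1)
        s = self.agg(s).reshape(B, C, N).transpose(1, 2)
        return self.proj(v * s)            # modulation: elementwise gating of the value

# Usage: a 14x14 feature map with 64 channels, flattened to tokens
x = torch.randn(2, 14 * 14, 64)
out = ScaleAwareModulationSketch(64)(x, 14, 14)
print(out.shape)  # torch.Size([2, 196, 64])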
Module structure
Module code
import torch
import torch.nn as nn
import torch.nn.functional as F
class SAM(nn.Module):
    def __init__(self, dim, ca_num_heads=4, sa_num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., expand_ratio=2):
        super().__init__()
        self.ca_attention = 1              # use the convolutional modulation branch
        self.dim = dim
        self.ca_num_heads = ca_num_heads   # heads for the multi-head mixed convolution
        self.sa_num_heads = sa_num_heads   # heads for the self-attention branch
        assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."