PyTorch Implementations of Various Attention Mechanisms

This post collects PyTorch implementations of a variety of attention mechanisms, including external attention, self-attention, a simplified self-attention, Squeeze-and-Excitation, Selective Kernel, CBAM, BAM, and ECA. It also covers the MLP-style models RepMLP, MLP-Mixer, ResMLP, and gMLP, as well as an efficient multi-head self-attention. The implementations span visual tasks ranging from image recognition to scene understanding.


The post gives PyTorch code for the entire series, along with a usage example for each module.

Attention Mechanisms
PyTorch implementation of “Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks—arXiv 2021.05.05”

PyTorch implementation of “Attention Is All You Need—NIPS2017”

PyTorch implementation of “Squeeze-and-Excitation Networks—CVPR2018”

PyTorch implementation of “Selective Kernel Networks—CVPR2019”

PyTorch implementation of “CBAM: Convolutional Block Attention Module—ECCV2018”

PyTorch implementation of “BAM: Bottleneck Attention Module—BMVC2018”

PyTorch implementation of “ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks—CVPR2020”

PyTorch implementation of “Dual Attention Network for Scene Segmentation—CVPR2019”

PyTorch implementation of “EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network—arXiv 2021.05.30”

PyTorch implementation of “ResT: An Efficient Transformer for Visual Recognition—arXiv 2021.05.28”

  1. External Attention
    1.1. Paper
    “Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks”

1.2. Overview

External attention replaces the self-attention map with two small, learnable external memory units implemented as linear layers, which drops the complexity from quadratic to linear in the number of tokens and shares the memories across the whole dataset.

1.3. Code

from attention.ExternalAttention import ExternalAttention
import torch
 
input=torch.randn(50,49,512)
ea = ExternalAttention(d_model=512,S=8)
output=ea(input)
print(output.shape)
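The snippet above only exercises the module's interface. Below is a minimal sketch of the idea itself, two learnable memory units with double normalization, written from the paper's description rather than copied from the repository, so the class name and defaults are my own.

```python
import torch
from torch import nn


class ExternalAttentionSketch(nn.Module):
    def __init__(self, d_model, S=8):
        super().__init__()
        self.mk = nn.Linear(d_model, S, bias=False)   # external "key" memory
        self.mv = nn.Linear(S, d_model, bias=False)   # external "value" memory

    def forward(self, x):                             # x: (batch, n_tokens, d_model)
        attn = torch.softmax(self.mk(x), dim=1)       # normalize over tokens
        attn = attn / (attn.sum(dim=2, keepdim=True) + 1e-9)  # l1-normalize over memory slots
        return self.mv(attn)                          # project back to d_model


x = torch.randn(50, 49, 512)
print(ExternalAttentionSketch(512, S=8)(x).shape)     # torch.Size([50, 49, 512])
```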
  2. Self-Attention
    2.1. Paper
    “Attention Is All You Need”

2.2. Overview

Scaled dot-product attention computes softmax(QK^T / sqrt(d_k)) V so that every token can attend to every other token; running several heads in parallel on different linear projections gives multi-head self-attention.

2.3. Code

from attention.SelfAttention import ScaledDotProductAttention
import torch
 
input=torch.randn(50,49,512)
sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)
  3. Simplified Self-Attention
    3.1. Paper
    None

3.2. Overview

A stripped-down variant of multi-head scaled dot-product attention provided by the same repository; it has no accompanying paper.

3.3. Code

from attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
import torch
 
input=torch.randn(50,49,512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output=ssa(input,input,input)
print(output.shape)
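Since there is no paper, the reading below is an assumption on my part: multi-head scaled dot-product attention with the learned Q/K/V projections removed. Treat it as a hypothetical illustration, not the repository's actual SimplifiedScaledDotProductAttention.

```python
import torch
from torch import nn


class NoProjectionAttention(nn.Module):
    """Hypothetical 'simplified' attention: heads + scaled dot product, no learned projections."""

    def __init__(self, d_model, h=8):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h

    def forward(self, q, k, v):
        b, n, _ = q.shape
        # split the model dimension into h heads
        q = q.view(b, n, self.h, self.d_k).transpose(1, 2)
        k = k.view(b, -1, self.h, self.d_k).transpose(1, 2)
        v = v.view(b, -1, self.h, self.d_k).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.d_k ** 0.5, dim=-1)
        return (attn @ v).transpose(1, 2).reshape(b, n, -1)


x = torch.randn(50, 49, 512)
print(NoProjectionAttention(512, h=8)(x, x, x).shape)   # torch.Size([50, 49, 512])
```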
 
  4. Squeeze-and-Excitation Attention
    4.1. Paper
    “Squeeze-and-Excitation Networks”

4.2. Overview

SE attention squeezes global spatial information into a per-channel descriptor with global average pooling, passes it through a small bottleneck MLP, and uses the resulting sigmoid weights to recalibrate each channel.

4.3. Code

from attention.SEAttention import SEAttention
import torch
 
input=torch.randn(50,512,7,7)
se = SEAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)
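For reference, here is a minimal SE block written from the paper (global average pooling, a bottleneck MLP with a reduction ratio, sigmoid gating); the class name and layout are mine, not the repository's.

```python
import torch
from torch import nn


class SESketch(nn.Module):
    def __init__(self, channel, reduction=8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):                    # x: (b, c, h, w)
        b, c, _, _ = x.shape
        w = x.mean(dim=(2, 3))               # squeeze: global average pool -> (b, c)
        w = self.fc(w).view(b, c, 1, 1)      # excite: per-channel gate
        return x * w


x = torch.randn(50, 512, 7, 7)
print(SESketch(512, reduction=8)(x).shape)   # torch.Size([50, 512, 7, 7])
```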
  5. SK Attention
    5.1. Paper
    “Selective Kernel Networks”

5.2. Overview

SK attention runs parallel convolution branches with different kernel sizes and uses a softmax over branch-wise channel descriptors to adaptively select which receptive field each channel relies on.

5.3. Code

from attention.SKAttention import SKAttention
import torch
 
input=torch.randn(50,512,7,7)
se = SKAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)
  6. CBAM Attention
    6.1. Paper
    “CBAM: Convolutional Block Attention Module”

6.2. Overview

CBAM refines features in two sequential steps: channel attention computed from average- and max-pooled descriptors passed through a shared MLP, followed by spatial attention computed by convolving the channel-wise average and max maps.

6.3. Code

from attention.CBAM import CBAMBlock
import torch
 
input=torch.randn(50,512,7,7)
kernel_size=input.shape[2]
cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)
output=cbam(input)
print(output.shape)
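A compact sketch of the two sub-modules CBAM chains together, following the paper rather than the repository's exact code; all names below are my own.

```python
import torch
from torch import nn
import torch.nn.functional as F


class ChannelGate(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(F.adaptive_avg_pool2d(x, 1))   # shared MLP on avg-pooled descriptor
        mx = self.mlp(F.adaptive_max_pool2d(x, 1))    # and on max-pooled descriptor
        return torch.sigmoid(avg + mx)


class SpatialGate(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = x.mean(dim=1, keepdim=True)             # channel-wise average map
        mx = x.max(dim=1, keepdim=True).values        # channel-wise max map
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))


class CBAMSketch(nn.Module):
    def __init__(self, channel, reduction=16, kernel_size=7):
        super().__init__()
        self.cg = ChannelGate(channel, reduction)
        self.sg = SpatialGate(kernel_size)

    def forward(self, x):
        x = x * self.cg(x)      # refine channels first ...
        return x * self.sg(x)   # ... then spatial locations


x = torch.randn(50, 512, 7, 7)
print(CBAMSketch(512)(x).shape)   # torch.Size([50, 512, 7, 7])
```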
 
  7. BAM Attention
    7.1. Paper
    “BAM: Bottleneck Attention Module”

7.2. Overview

BAM is placed at network bottlenecks and combines a channel branch with a dilated-convolution spatial branch, merging the two into a single attention map that is applied residually.

7.3. Code

from attention.BAM import BAMBlock
import torch
 
input=torch.randn(50,512,7,7)
bam = BAMBlock(channel=512,reduction=16,dia_val=2)
output=bam(input)
print(output.shape)
 
  8. ECA Attention
    8.1. Paper
    “ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks”

8.2. Overview

ECA avoids dimensionality reduction entirely: it captures local cross-channel interaction with a single k-sized 1-D convolution over the globally pooled channel descriptor.

8.3. Code

from attention.ECAAttention import ECAAttention
import torch
 
input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)
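ECA's core fits in a few lines: no channel reduction, just a k-sized 1-D convolution over the pooled channel vector. A paper-level sketch (names mine):

```python
import torch
from torch import nn


class ECASketch(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):                              # x: (b, c, h, w)
        b, c, _, _ = x.shape
        y = x.mean(dim=(2, 3)).view(b, 1, c)           # pooled channel descriptor as a 1-D signal
        y = torch.sigmoid(self.conv(y)).view(b, c, 1, 1)  # local cross-channel interaction
        return x * y


x = torch.randn(50, 512, 7, 7)
print(ECASketch(kernel_size=3)(x).shape)               # torch.Size([50, 512, 7, 7])
```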
 
  9. DANet Attention
    9.1. Paper
    “Dual Attention Network for Scene Segmentation”

9.2. Overview

DANet combines a position attention module (self-attention across spatial locations) with a channel attention module (self-attention across channels) and fuses their outputs for scene segmentation.

9.3. Code

from attention.DANet import DAModule
import torch
 
if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
    print(danet(input).shape)
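DANet's position attention module is essentially self-attention over spatial locations with a learnable residual weight; the channel attention module does the analogous thing across channels. A minimal sketch of the position branch, written from the paper (names mine):

```python
import torch
from torch import nn


class PositionAttentionSketch(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))                # residual weight, learned from 0

    def forward(self, x):                                        # x: (b, c, h, w)
        b, c, h, w = x.shape
        q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)    # (b, hw, c//8)
        k = self.key(x).view(b, -1, h * w)                       # (b, c//8, hw)
        attn = torch.softmax(torch.bmm(q, k), dim=-1)            # pixel-to-pixel affinity
        v = self.value(x).view(b, -1, h * w)                     # (b, c, hw)
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x


x = torch.randn(50, 512, 7, 7)
print(PositionAttentionSketch(512)(x).shape)                     # torch.Size([50, 512, 7, 7])
```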
 
  10. Pyramid Split Attention
    10.1. Paper
    “EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network”

10.2. Overview

Pyramid Split Attention splits the channels into groups processed by convolutions with different kernel sizes, applies SE-style channel attention to each group, and recalibrates the groups with a softmax across them.

10.3. Code

from attention.PSA import PSA
import torch
 
if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    psa = PSA(channel=512,reduction=8)
    output=psa(input)
    print(output.shape)
 
  11. Efficient Multi-Head Self-Attention
    11.1. Paper
    “ResT: An Efficient Transformer for Visual Recognition”

11.2. Overview

ResT's efficient multi-head self-attention shrinks the spatial resolution of the keys and values with a strided convolution before the attention step, cutting memory and compute for high-resolution inputs.

11.3. Code

from attention.EMSA import EMSA
import torch
from torch import nn
from torch.nn import functional as F
 
if __name__ == '__main__':
    input=torch.randn(50,64,512)
    emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)
    output=emsa(input,input,input)
    print(output.shape)
    

MLP Series
PyTorch implementation of “RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition—arXiv 2021.05.05”

PyTorch implementation of “MLP-Mixer: An all-MLP Architecture for Vision—arXiv 2021.05.17”

PyTorch implementation of “ResMLP: Feedforward networks for image classification with data-efficient training—arXiv 2021.05.07”

PyTorch implementation of “Pay Attention to MLPs—arXiv 2021.05.17”

  1. RepMLP
    1.1. Paper
    “RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition”

1.2. Overview

RepMLP trains with both fully connected layers and parallel convolutional branches, then re-parameterizes the convolutions into the FC weights so that inference runs on pure fully connected layers; the example below checks that training-time and deploy-time outputs match.

1.3. Code

from mlp.repmlp import RepMLP
import torch
from torch import nn
 
N=4 #batch size
C=512 #input dim
O=1024 #output dim
H=14 #image height
W=14 #image width
h=7 #patch height
w=7 #patch width
fc1_fc2_reduction=1 #reduction ratio
fc3_groups=8 # groups
repconv_kernels=[1,3,5,7] #kernel list
repmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)
x=torch.randn(N,C,H,W)
repmlp.eval()
for module in repmlp.modules():
    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
        nn.init.uniform_(module.running_mean, 0, 0.1)
        nn.init.uniform_(module.running_var, 0, 0.1)
        nn.init.uniform_(module.weight, 0, 0.1)
        nn.init.uniform_(module.bias, 0, 0.1)
 
#training result
out=repmlp(x)
#inference result
repmlp.switch_to_deploy()
deployout = repmlp(x)
 
print(((deployout-out)**2).sum())
  2. MLP-Mixer
    2.1. Paper
    “MLP-Mixer: An all-MLP Architecture for Vision”

2.2. Overview

MLP-Mixer alternates token-mixing MLPs (applied across image patches) with channel-mixing MLPs (applied across features), using neither convolutions nor attention.

2.3. Code

from mlp.mlp_mixer import MlpMixer
import torch
mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)
input=torch.randn(50,3,40,40)
output=mlp_mixer(input)
print(output.shape)
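The core of MLP-Mixer is a block that first mixes information across patches and then across features. A paper-level sketch of one such block, with names and the toy shapes chosen by me:

```python
import torch
from torch import nn


class MixerBlockSketch(nn.Module):
    def __init__(self, num_tokens, dim, tokens_mlp_dim, channels_mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(                 # mixes across patches
            nn.Linear(num_tokens, tokens_mlp_dim), nn.GELU(),
            nn.Linear(tokens_mlp_dim, num_tokens),
        )
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(               # mixes across features
            nn.Linear(dim, channels_mlp_dim), nn.GELU(),
            nn.Linear(channels_mlp_dim, dim),
        )

    def forward(self, x):                               # x: (batch, tokens, dim)
        y = self.norm1(x).transpose(1, 2)               # (batch, dim, tokens)
        x = x + self.token_mlp(y).transpose(1, 2)       # token mixing + residual
        return x + self.channel_mlp(self.norm2(x))      # channel mixing + residual


x = torch.randn(50, 16, 1024)
print(MixerBlockSketch(num_tokens=16, dim=1024,
                       tokens_mlp_dim=16, channels_mlp_dim=1024)(x).shape)
```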
  3. ResMLP
    3.1. Paper
    “ResMLP: Feedforward networks for image classification with data-efficient training”

3.2. Overview

ResMLP stacks residual blocks that alternate a linear layer mixing information across patches with a two-layer feed-forward network applied per patch, trained with a data-efficient recipe.

3.3. Code

from mlp.resmlp import ResMLP
import torch
 
input=torch.randn(50,3,14,14)
resmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)
out=resmlp(input)
print(out.shape) # the last dimension is class_num
  4. gMLP
    4.1. Paper
    “Pay Attention to MLPs”

4.2. Overview

gMLP replaces self-attention with a Spatial Gating Unit: the hidden representation is split in two, and one half is projected along the sequence dimension and used to gate the other half element-wise.

4.3. Code

from mlp.g_mlp import gMLP
import torch
 
num_tokens=10000
bs=50
len_sen=49
num_layers=6
input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
output=gmlp(input)
print(output.shape)
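What replaces attention in gMLP is the Spatial Gating Unit: split the hidden representation in half, project one half along the sequence dimension, and use it to gate the other. A sketch written from the paper (not the repository's class; names are mine):

```python
import torch
from torch import nn


class SpatialGatingUnit(nn.Module):
    """Sketch of gMLP's gating unit: half the channels gate the other half
    after a learned projection along the sequence dimension."""

    def __init__(self, d_ff, len_sen):
        super().__init__()
        self.norm = nn.LayerNorm(d_ff // 2)
        self.proj = nn.Linear(len_sen, len_sen)
        nn.init.zeros_(self.proj.weight)   # near-identity gating at init, as the paper suggests
        nn.init.ones_(self.proj.bias)

    def forward(self, x):                  # x: (batch, len_sen, d_ff)
        u, v = x.chunk(2, dim=-1)
        v = self.norm(v)
        v = self.proj(v.transpose(1, 2)).transpose(1, 2)   # mix along the sequence
        return u * v


x = torch.randn(50, 49, 1024)
print(SpatialGatingUnit(d_ff=1024, len_sen=49)(x).shape)   # torch.Size([50, 49, 512])
```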

Re-Parameter Series
PyTorch implementation of “RepVGG: Making VGG-style ConvNets Great Again—CVPR2021”

PyTorch implementation of “ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks—ICCV2019”

  1. RepVGG
    1.1. Paper
    “RepVGG: Making VGG-style ConvNets Great Again”

1.2. Overview

RepVGG trains a multi-branch block (3x3 convolution, 1x1 convolution, and identity, each followed by BatchNorm) and re-parameterizes the branches into a single 3x3 convolution for inference, so the deployed network is a plain VGG-style stack.

1.3. Code

from rep.repvgg import RepBlock
import torch
 
 
input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())
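The near-zero difference printed above relies on folding each branch's BatchNorm into its convolution and summing the resulting kernels into a single 3x3 conv. A generic fold-BN-into-conv helper, sketched here as an illustration of the re-parameterization step (not the repository's function):

```python
import torch
from torch import nn
import torch.nn.functional as F


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Return (weight, bias) of one conv equivalent to conv followed by bn in eval mode."""
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                          # per-output-channel scaling from BN
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = bn.bias - bn.running_mean * scale
    if conv.bias is not None:
        bias = bias + conv.bias * scale
    return weight, bias


# sanity check with random BN statistics
conv = nn.Conv2d(16, 32, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(32).eval()
nn.init.uniform_(bn.running_mean, 0, 0.1)
nn.init.uniform_(bn.running_var, 0.5, 1.5)
x = torch.randn(2, 16, 8, 8)
w, b = fuse_conv_bn(conv, bn)
fused = F.conv2d(x, w, b, padding=1)
print((fused - bn(conv(x))).abs().max())             # close to zero (floating-point tolerance)
```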
  2. ACNet
    2.1. Paper
    “ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks”

2.2. Overview

ACNet strengthens the kernel skeleton by training 3x3, 1x3, and 3x1 convolution branches in parallel and fusing them into a single 3x3 kernel at deployment, again with no inference-time cost.

2.3. Code

from rep.acnet import ACNet
import torch
from torch import nn
 
input=torch.randn(50,512,49,49)
acnet=ACNet(512,512)
acnet.eval()
out=acnet(input)
acnet._switch_to_deploy()
out2=acnet(input)
print('difference:')
print(((out2-out)**2).sum())
 
### Implementing Multi-Head Attention

In neural machine translation and other sequence-modeling tasks, multi-head attention lets the model focus on different parts of the input, which improves performance. The following shows one way to implement multi-head attention with PyTorch.

```python
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Projection layers W_Q, W_K, W_V and the output projection W_O
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * d_k)

    def forward(self, query, key, value, mask=None):
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask=mask)
        combined_attn_output = self.combine_heads(attn_output)
        final_output = self.W_o(combined_attn_output)
        return final_output
```

This snippet defines a `MultiHeadAttention` class implementing the core logic of multi-head self-attention. Splitting the query, key, and value vectors into multiple heads and computing attention scores for each head separately lets the model capture relationship patterns between different positions.
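As a quick check, the class above can be run on the same tensor shapes as the self-attention example earlier in the post (this usage snippet is my addition):

```python
x = torch.randn(50, 49, 512)
mha = MultiHeadAttention(d_model=512, num_heads=8)
print(mha(x, x, x).shape)   # torch.Size([50, 49, 512])
```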