本文接续上文,继续介绍如何在YOLOv10中添加其他先进的注意力机制。超30种注意力机制模块,助力文章涨点多多🛩️🛩️
提示:喜欢本专栏的小伙伴,请多多点赞关注支持。本文仅供学习交流使用,创作不易,未经作者允许,不得搬运或转载!!!
注意力机制模块介绍🛩️🛩️
16、 SGE🌱🌱
论文地址:https://arxiv.org/pdf/1905.09646
SGE(Spatial Group-wise Enhance)注意力机制先沿通道维度将特征图划分为多个组(每组视为一个子特征),再在每个组内计算该组全局平均特征与各空间位置特征的相似度,经归一化和Sigmoid后得到空间注意力图,对该组特征进行有选择性的增强,从而提升模型对语义信息的捕捉能力。其核心思想是:组内与全局语义向量更相似的空间位置得到更高的权重,从而突出语义相关区域、抑制噪声,提高特征表示的语义丰富度。
优势:
- 分组特征增强:通过在通道维度上分组,并在每组内生成空间注意力图,SGE能够突出每个子特征中语义相关的空间区域,提高特征表示的语义丰富度。
- 计算效率高:SGE的注意力仅由全局平均池化、逐位置点积和归一化构成,每组只引入两个可学习参数(weight与bias),几乎不增加参数量和计算量,适合高效模型的设计。
- 易于集成:SGE的结构设计简单,容易集成到现有的卷积神经网络中,提升其性能。
import numpy as np
import torch
from torch import nn
from torch.nn import init
class SpatialGroupEnhance(nn.Module):
    """Spatial Group-wise Enhance (SGE) attention.

    The channels are split into ``groups`` sub-features. Inside each group the
    similarity between every spatial position and the group's globally
    averaged feature is normalised (zero mean, unit std over the spatial
    positions), scaled/shifted by a learnable per-group ``weight``/``bias``,
    and passed through a sigmoid to gate that group's features.

    Args:
        groups: number of channel groups; the input channel count must be
            divisible by it.

    Input: ``(b, c, h, w)`` tensor. Output: tensor of the same shape.
    """

    def __init__(self, groups=8):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # One learnable scale/shift pair per group, applied to the normalised
        # similarity map before the sigmoid.
        self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.sig = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Initialise Conv2d / BatchNorm2d / Linear submodules.

        This module currently contains none of those layer types, so the loop
        leaves ``weight``/``bias`` at zero; it is kept for a uniform interface.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        # reshape (not view): channels_last or otherwise non-contiguous inputs
        # cannot be viewed as (b*g, c//g, h, w) and would raise.
        x = x.reshape(b * self.groups, -1, h, w)  # b*g, c//g, h, w
        xn = x * self.avg_pool(x)  # b*g, c//g, h, w
        xn = xn.sum(dim=1, keepdim=True)  # b*g, 1, h, w  (similarity to group mean)
        t = xn.reshape(b * self.groups, -1)  # b*g, h*w
        t = t - t.mean(dim=1, keepdim=True)  # b*g, h*w
        # Unbiased std (torch default); NOTE: this is NaN when h*w == 1.
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std  # b*g, h*w
        t = t.reshape(b, self.groups, h, w)  # b, g, h, w
        t = t * self.weight + self.bias  # b, g, h, w
        t = t.reshape(b * self.groups, 1, h, w)  # b*g, 1, h, w
        x = x * self.sig(t)
        return x.reshape(b, c, h, w)
17、 A2Attention( Double Attention)🌱🌱
论文地址:https://arxiv.org/pdf/1810.11579
A2Attention是一种双重注意力机制,分两步完成:第一步(特征收集)用一组在空间维度上归一化的注意力图,从整幅特征图中汇聚出若干全局描述子;第二步(特征分配)在每个位置用一个在描述子维度上归一化的注意力向量,把这些全局描述子自适应地分配回该位置。这样每个位置都能获得来自全图的上下文信息,同时兼顾了不同特征通道的重要性,从而提供更加全面的特征增强。
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
class DoubleAttention(nn.Module):
    """A^2 (double attention) block.

    Step 1 (gather): ``c_n`` attention maps, each normalised over the spatial
    positions, pool the ``c_m`` features produced by ``convA`` into ``c_n``
    global descriptors.
    Step 2 (distribute): at every position an attention vector normalised over
    the ``c_n`` descriptors mixes them back into a ``c_m``-dim feature.
    An optional 1x1 conv (``reconstruct``) maps ``c_m`` back to
    ``in_channels``.

    Args:
        in_channels: channels of the input; checked in ``forward``.
        c_m: width of the gathered features.
        c_n: number of global descriptors.
        reconstruct: project the ``c_m`` output back to ``in_channels``.

    Returns ``(b, in_channels, h, w)`` when ``reconstruct`` else
    ``(b, c_m, h, w)``.
    """

    def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):
        super().__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct
        self.c_m = c_m
        self.c_n = c_n
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)
        self.init_weights()

    def init_weights(self):
        """Kaiming-init convs (zero bias); BN to identity; small-normal Linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        assert c == self.in_channels
        A = self.convA(x)  # b, c_m, h, w
        B = self.convB(x)  # b, c_n, h, w
        V = self.convV(x)  # b, c_n, h, w
        # reshape (not view) so channels_last conv outputs are accepted too.
        tmpA = A.reshape(b, self.c_m, -1)  # b, c_m, h*w
        # The softmax dims must be explicit: F.softmax without ``dim`` on a
        # 3-D tensor normalises over dim 0, i.e. across the *batch*, which
        # couples independent samples (and yields all ones when b == 1).
        attention_maps = F.softmax(B.reshape(b, self.c_n, -1), dim=-1)  # over h*w
        attention_vectors = F.softmax(V.reshape(b, self.c_n, -1), dim=1)  # over c_n
        # step 1: feature gathering
        global_descriptors = torch.bmm(tmpA, attention_maps.permute(0, 2, 1))  # b, c_m, c_n
        # step 2: feature distribution
        tmpZ = global_descriptors.matmul(attention_vectors)  # b, c_m, h*w
        tmpZ = tmpZ.reshape(b, self.c_m, h, w)  # b, c_m, h, w
        if self.reconstruct:
            tmpZ = self.conv_reconstruct(tmpZ)
        return tmpZ
18、GC(Global Context )🌱🌱
论文地址: https://arxiv.org/abs/1904.11492
Global Context (GC) 注意力机制通过捕捉特征图中的全局上下文信息,来提升模型对长距离依赖关系的建模能力。GC先用1×1卷积加Softmax在所有空间位置上计算注意力权重,加权汇聚得到一个与查询位置无关、所有位置共享的全局上下文向量;该向量经过瓶颈变换后,再对原始特征进行通道加权(或相加融合),从而增强特征表示。它结合了非局部(Non-local)操作的全局建模能力和SE模块的轻量通道注意力,使得模型能够更好地捕捉全局和局部的特征相关性。
import torch
from torch import nn as nn
import torch.nn.functional as F
from timm.models.layers.create_act import create_act_layer, get_act_layer
from timm.models.layers.helpers import make_divisible
from timm.models.layers.mlp import ConvMlp
from timm.models.layers.norm import LayerNorm2d
class GlobalContext(nn.Module):
    """Global Context (GC) block, following the timm implementation.

    A 1x1 conv followed by a softmax over all H*W positions produces one
    attention-pooled context vector per sample (a plain spatial mean is used
    instead when ``use_attn`` is False). That context feeds up to two ConvMlp
    branches: ``fuse_scale`` gates the input multiplicatively through
    ``gate_layer``, and ``fuse_add`` adds its output to the input.

    Args:
        channels: input/output channels.
        use_attn: attention pooling (True) vs. global average pooling.
        fuse_add / fuse_scale: enable the additive / multiplicative branch.
        init_last_zero: stored on the module; not read by this class.
        rd_ratio / rd_channels / rd_divisor: hidden width of the ConvMlps.
        act_layer: activation inside the ConvMlps.
        gate_layer: gating activation for the multiplicative branch.
    """

    def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
                 rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
        super(GlobalContext, self).__init__()
        act_layer = get_act_layer(act_layer)
        self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
        hidden = rd_channels
        if hidden is None:
            hidden = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.mlp_add = (
            ConvMlp(channels, hidden, act_layer=act_layer, norm_layer=LayerNorm2d) if fuse_add else None
        )
        self.mlp_scale = (
            ConvMlp(channels, hidden, act_layer=act_layer, norm_layer=LayerNorm2d) if fuse_scale else None
        )
        self.gate = create_act_layer(gate_layer)
        self.init_last_zero = init_last_zero
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init the attention conv; zero the additive branch's last layer."""
        if self.conv_attn is not None:
            nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
        if self.mlp_add is not None:
            nn.init.zeros_(self.mlp_add.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.conv_attn is None:
            context = x.mean(dim=(2, 3), keepdim=True)
        else:
            logits = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H*W)
            weights = F.softmax(logits, dim=-1).unsqueeze(3)  # (B, 1, H*W, 1)
            pooled = x.reshape(B, C, H * W).unsqueeze(1) @ weights  # (B, 1, C, 1)
            context = pooled.view(B, C, 1, 1)
        if self.mlp_scale is not None:
            x = x * self.gate(self.mlp_scale(context))
        if self.mlp_add is not None:
            x = x + self.mlp_add(context)
        return x
19、 EffectiveSE(Effective Squeeze-Excitation)🌱🌱
论文地址:https://arxiv.org/abs/1911.06667
ESE注意力机制的核心思想是通过对特征图的通道维度进行加权调整,增强重要特征,同时抑制不重要的特征。SE模块使用两层(先降维再升维)全连接层,而ESE只使用一层不降维的全连接层(实现为1×1卷积)。这样既减少了一层计算,又避免了降维带来的通道信息损失,使模块更加轻量化和高效。ESE模块能够在多种视觉任务中提升卷积神经网络的性能,特别适合需要高效计算的应用场景。
import torch
from torch import nn as nn
from timm.models.layers.create_act import create_act_layer
class EffectiveSEModule(nn.Module):
    """Effective Squeeze-Excitation (eSE).

    Global average pooling (optionally blended 50/50 with global max pooling)
    is followed by a single channels->channels 1x1 conv and a gate activation.
    The result rescales the input channel-wise.

    Args:
        channels: input/output channels.
        add_maxpool: also blend in global max pooling (experimental).
        gate_layer: name/type of the gating activation passed to timm.
    """

    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
        super(EffectiveSEModule, self).__init__()
        self.add_maxpool = add_maxpool
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        squeezed = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath (from timm): average of avg- and max-pooling
            squeezed = 0.5 * squeezed + 0.5 * x.amax((2, 3), keepdim=True)
        excitation = self.gate(self.fc(squeezed))
        return x * excitation
20、Criss-Cross Attention🌱🌱
论文地址:https://arxiv.org/pdf/1811.11721
Criss-Cross Attention是一种在语义分割任务中用于捕获图像全局上下文信息的机制,旨在有效且高效地聚合全图像的上下文信息。其核心思想是:对于图像中的每个像素点,通过一个专门的注意力模块收集其所在行和所在列(即水平和垂直方向的十字交叉路径)上所有像素的上下文信息。将该模块循环应用两次后,每个像素即可间接获得全图的上下文。与对所有像素两两计算注意力的Non-local相比,这种设计显著降低了计算复杂度和GPU显存占用。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax
def INF(B, H, W, device):
    """Build the criss-cross self-exclusion mask.

    Returns a ``(B * W, H, H)`` tensor on ``device``, in the default float
    dtype, where every ``H x H`` slice has ``-inf`` on the diagonal and zeros
    elsewhere.
    """
    diagonal = torch.full((H,), float("inf"), device=device)
    mask = torch.diag(diagonal, 0).neg()
    return mask.unsqueeze(0).repeat(B * W, 1, 1)
class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention Module (CCNet).

    Every position attends to all positions in its own column and its own row
    (H + W keys per query). The column energies get a ``-inf`` diagonal from
    ``INF`` so the position itself is only counted once, through the row
    branch. The attended values are scaled by a learnable ``gamma``
    (initialised to 0) and added to the input.

    Args:
        in_dim: input channels; query/key projections use ``in_dim // 8``.
    """

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        device = x.device
        # NOTE: moves the module to the input's device on every call (kept
        # for compatibility with existing callers).
        self.to(device)
        m_batchsize, _, height, width = x.size()

        # Queries: one sequence per column (H) and per row (W).
        proj_query = self.query_conv(x)
        proj_query_H = (
            proj_query.permute(0, 3, 1, 2).contiguous()
            .view(m_batchsize * width, -1, height).permute(0, 2, 1)
        )  # (b*w, h, c')
        proj_query_W = (
            proj_query.permute(0, 2, 1, 3).contiguous()
            .view(m_batchsize * height, -1, width).permute(0, 2, 1)
        )  # (b*h, w, c')

        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        # INF() always builds a float32 mask. Cast it to the working dtype:
        # otherwise half-precision inputs promote energy_H (and the softmax
        # output) to float32, and the bmm with the half-precision values below
        # fails on a dtype mismatch.
        diag_mask = self.INF(m_batchsize, height, width, device).to(proj_query_H.dtype)
        energy_H = (
            (torch.bmm(proj_query_H, proj_key_H) + diag_mask)
            .view(m_batchsize, width, height, height)
            .permute(0, 2, 1, 3)
        )  # (b, h, w, h)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)  # (b, h, w, w)

        # Joint softmax over the H + W criss-cross positions of each query.
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        att_H = (
            concate[:, :, :, 0:height]
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(m_batchsize * width, height, height)
        )
        att_W = (
            concate[:, :, :, height: height + width]
            .contiguous()
            .view(m_batchsize * height, width, width)
        )

        out_H = (
            torch.bmm(proj_value_H, att_H.permute(0, 2, 1))
            .view(m_batchsize, width, -1, height)
            .permute(0, 2, 3, 1)
        )  # (b, c, h, w)
        out_W = (
            torch.bmm(proj_value_W, att_W.permute(0, 2, 1))
            .view(m_batchsize, height, -1, width)
            .permute(0, 2, 1, 3)
        )  # (b, c, h, w)
        return self.gamma * (out_H + out_W) + x
21、 GE(Gather-Excite Attention )🌱🌱
论文地址: https://arxiv.org/abs/1810.12348
Gather-Excite(GE)注意力机制在《Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks》一文中提出,该注意力机制旨在通过重新校准卷积神经网络(CNN)中的特征图来增强特征的表达能力,从而提高模型的性能。
import math
from timm.layers.create_act import create_act_layer, get_act_layer
from timm.layers.create_conv2d import create_conv2d
from timm.layers.helpers import make_divisible
from timm.layers.mlp import ConvMlp
class GatherExcite(nn.Module):
    """Gather-Excite attention (timm implementation).

    Gather: aggregate context either globally (``extent == 0``) or with a
    ``2*extent - 1`` pooling window and stride ``extent``. When
    ``extra_params`` is set, learned depthwise convs do the gathering instead
    of parameter-free pooling.
    Excite: an optional ConvMlp, nearest-neighbour interpolation back to the
    input size when the gathered map is not 1x1, and a gate that rescales
    ``x``.

    Args:
        channels: input channels.
        feat_size: kernel size of the global depthwise conv; required when
            ``extra_params`` and ``extent == 0``.
        extra_params: gather with learned depthwise convs.
        extent: 0 for global context, otherwise an even extent.
        use_mlp: apply a ConvMlp to the gathered context.
        rd_ratio / rd_channels / rd_divisor: ConvMlp hidden width.
        add_maxpool: also blend in max pooling (parameter-free path only).
        act_layer: activation between gather convs and inside the ConvMlp.
        norm_layer: norm layer class used after each gather conv; a falsy
            value disables it.
        gate_layer: final gating activation.
    """

    def __init__(
        self,
        channels,
        feat_size=None,
        extra_params=False,
        extent=0,
        use_mlp=True,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=1,
        add_maxpool=False,
        act_layer=nn.ReLU,
        norm_layer=nn.BatchNorm2d,
        gate_layer="sigmoid",
    ):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert (
                    feat_size is not None
                ), "spatial feature size must be specified for global extent w/ params"
                self.gather.add_module(
                    "conv1",
                    create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True),
                )
                if norm_layer:
                    # Use the configured norm layer (it used to be hard-coded
                    # to BatchNorm2d, silently ignoring ``norm_layer``).
                    self.gather.add_module("norm1", norm_layer(channels))
            else:
                assert extent % 2 == 0
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f"conv{i + 1}",
                        create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True),
                    )
                    if norm_layer:
                        self.gather.add_module(f"norm{i + 1}", norm_layer(channels))
                    if i != num_conv - 1:
                        self.gather.add_module(f"act{i + 1}", act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.0)
        self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        elif self.extent == 0:
            # global extent
            x_ge = x.mean(dim=(2, 3), keepdim=True)
            if self.add_maxpool:
                # experimental codepath, may remove or change
                x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
        else:
            x_ge = F.avg_pool2d(
                x,
                kernel_size=self.gk,
                stride=self.gs,
                padding=self.gk // 2,
                count_include_pad=False,
            )
            if self.add_maxpool:
                # experimental codepath, may remove or change
                x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(
                    x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2
                )
        x_ge = self.mlp(x_ge)
        # Local (non-global) context is upsampled back to the input size.
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)
22、PSA (Polarized Self-Attention)🌱🌱
论文地址:https://arxiv.org/abs/2107.00782
Polarized Self-Attention(PSA)旨在提高像素级回归任务(如图像分割、深度估计等)的性能。其核心是"极化滤波":通道分支将空间维度完全压缩,空间分支将通道维度完全压缩,同时在未被压缩的维度上保持较高的内部分辨率(如通道数的一半),再用Softmax与Sigmoid的组合对注意力进行非线性调整。这种机制在提高模型准确性的同时计算开销较小,适用于高分辨率图像处理任务。
优点:
- 有助于提升模型准确性:PSA通过有效地分离和处理内容和位置特征,显著提高了像素级回归任务的准确性。实验表明,PSA在多个基准数据集上都取得了优异的性能。
- 减少计算复杂度:PSA通过引入极化机制,减少了传统自注意机制的计算负担,使得模型在处理高分辨率图像时更加高效。
- 易于集成:PSA机制易于与现有的深度学习模型集成,无需对现有模型进行大幅度修改。
class ParallelPolarizedSelfAttention(nn.Module):
    """Polarized Self-Attention, parallel layout.

    A channel-only branch and a spatial-only branch are both computed from the
    input ``x``. Each produces a sigmoid weight map that rescales ``x``, and
    the two rescaled tensors are summed.

    Args:
        channel: number of input channels; the value/query projections use
            ``channel // 2`` channels.
    """

    def __init__(self, channel=512):
        super().__init__()
        # channel-only branch
        self.ch_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(channel // 2, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()
        # spatial-only branch
        self.sp_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def _channel_weight(self, x):
        """Channel-only attention: (b, c, h, w) -> (b, c, 1, 1) weights."""
        b, c = x.shape[:2]
        value = self.ch_wv(x).reshape(b, c // 2, -1)  # b, c//2, h*w
        query = self.softmax_channel(self.ch_wq(x).reshape(b, -1, 1))  # b, h*w, 1
        pooled = torch.matmul(value, query).unsqueeze(-1)  # b, c//2, 1, 1
        z = self.ch_wz(pooled).reshape(b, c, 1).permute(0, 2, 1)  # b, 1, c
        return self.sigmoid(self.ln(z)).permute(0, 2, 1).reshape(b, c, 1, 1)

    def _spatial_weight(self, x):
        """Spatial-only attention: (b, c, h, w) -> (b, 1, h, w) weights."""
        b, c, h, w = x.shape
        value = self.sp_wv(x).reshape(b, c // 2, -1)  # b, c//2, h*w
        query = self.agp(self.sp_wq(x))  # b, c//2, 1, 1
        query = self.softmax_spatial(query.permute(0, 2, 3, 1).reshape(b, 1, c // 2))  # b, 1, c//2
        return self.sigmoid(torch.matmul(query, value).reshape(b, 1, h, w))

    def forward(self, x):
        channel_out = self._channel_weight(x) * x
        spatial_out = self._spatial_weight(x) * x
        return spatial_out + channel_out
23、Sequential Self-Attention🌱🌱
论文地址:https://arxiv.org/abs/2107.00782
Sequential Polarized Self-Attention与上一节的Parallel Polarized Self-Attention出自同一篇文章,二者的区别在于通道分支与空间分支的布局方式。并行布局中,两个分支都作用于原始输入,结果相加;顺序布局中,先执行通道分支,再将空间分支作用于通道分支的输出。
class SequentialPolarizedSelfAttention(nn.Module):
    """Polarized Self-Attention, sequential layout.

    The channel-only branch rescales the input first. The spatial-only branch
    is then computed from, and applied to, that channel-refined tensor.

    Args:
        channel: number of input channels; the value/query projections use
            ``channel // 2`` channels.
    """

    def __init__(self, channel=512):
        super().__init__()
        # channel-only branch
        self.ch_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(channel // 2, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()
        # spatial-only branch
        self.sp_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def _channel_weight(self, feat):
        """Channel-only attention: (b, c, h, w) -> (b, c, 1, 1) weights."""
        b, c = feat.shape[:2]
        value = self.ch_wv(feat).reshape(b, c // 2, -1)  # b, c//2, h*w
        query = self.softmax_channel(self.ch_wq(feat).reshape(b, -1, 1))  # b, h*w, 1
        pooled = torch.matmul(value, query).unsqueeze(-1)  # b, c//2, 1, 1
        z = self.ch_wz(pooled).reshape(b, c, 1).permute(0, 2, 1)  # b, 1, c
        return self.sigmoid(self.ln(z)).permute(0, 2, 1).reshape(b, c, 1, 1)

    def _spatial_weight(self, feat):
        """Spatial-only attention: (b, c, h, w) -> (b, 1, h, w) weights."""
        b, c, h, w = feat.shape
        value = self.sp_wv(feat).reshape(b, c // 2, -1)  # b, c//2, h*w
        query = self.agp(self.sp_wq(feat))  # b, c//2, 1, 1
        query = self.softmax_spatial(query.permute(0, 2, 3, 1).reshape(b, 1, c // 2))  # b, 1, c//2
        return self.sigmoid(torch.matmul(query, value).reshape(b, 1, h, w))

    def forward(self, x):
        refined = self._channel_weight(x) * x
        return self._spatial_weight(refined) * refined
24、GAM(Global Attention Mechanism)🌱🌱
论文地址:https://arxiv.org/pdf/2112.05561v1
Global Attention Mechanism(GAM)旨在增强通道与空间之间的交互,保留更多信息,从而提升图像处理任务的性能。该机制通过引入双重注意力机制,有效地提高了图像处理任务的性能。GAM为未来的计算机视觉任务提供了一种强大且高效的解决方案,能够在多个应用场景中取得优异的表现。
- 增强信息保留:通过全局范围的注意力机制,GAM能够保留更多的信息,提高模型对图像细节的理解能力。这种机制有助于模型更好地捕捉到图像中的重要特征,提升整体性能。
- 通道和空间的全面交互:GAM同时关注通道和空间之间的交互,提供了更全面的特征表示。这种双重注意力机制使得模型能够更准确地捕捉到图像中的复杂结构和细节。
- 提高模型性能:实验表明,GAM在多个基准数据集上的表现优异,有助于提高各种图像处理任务的性能,如图像分类、对象检测和分割等。
class GAMAttention(nn.Module):
    """Global Attention Mechanism (GAM).

    Channel attention: a two-layer MLP is applied to every pixel's channel
    vector (channels-last layout), followed by a sigmoid, and the result
    rescales ``x``.
    Spatial attention: two 7x7 convs (grouped when ``group``) with BatchNorm
    and a sigmoid produce a ``c2``-channel map. The map is channel-shuffled
    with 4 groups and multiplied with the channel-refined features.

    Args:
        c1: input channels.
        c2: channels of the spatial attention map. It multiplies ``x``
            element-wise, so it must equal ``c1``; it must also be divisible
            by 4 for the channel shuffle.
        group: use grouped (``groups=rate``) convs in the spatial branch.
        rate: channel reduction ratio for both branches.
    """

    def __init__(self, c1, c2, group=True, rate=4):
        super(GAMAttention, self).__init__()
        hidden = int(c1 / rate)
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c1),
        )
        conv_groups = rate if group else 1
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, hidden, kernel_size=7, padding=3, groups=conv_groups),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, c2, kernel_size=7, padding=3, groups=conv_groups),
            nn.BatchNorm2d(c2),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        # Per-pixel MLP over channels: (b, c, h, w) -> (b, h*w, c) -> MLP.
        x_permute = x.permute(0, 2, 3, 1).reshape(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        # GAM applies a sigmoid to the channel map (as in the spatial branch);
        # without it the unbounded, signed MLP output would scale x directly.
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        return x * x_spatial_att
def channel_shuffle(x, groups=2):
    """Interleave channels across ``groups`` (ShuffleNet-style shuffle).

    The channels are viewed as a ``(groups, C // groups)`` grid, the grid is
    transposed, and the result is flattened back. For C=8, groups=2 the output
    channel order is 0, 4, 1, 5, 2, 6, 3, 7.

    Args:
        x: ``(B, C, H, W)`` tensor; C must be divisible by ``groups``.
        groups: number of channel groups.
    """
    batch, channels, height, width = x.size()
    grid = x.view(batch, groups, channels // groups, height, width)
    shuffled = grid.transpose(1, 2).contiguous()
    return shuffled.view(batch, channels, height, width)
25、Biformer🌱🌱
论文地址: https://arxiv.org/abs/2303.08810
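BiFormer通过引入双层路由注意力(Bi-Level Routing Attention)机制,有效地提高了视觉Transformer的性能和效率。该机制分两级计算注意力:先在粗粒度的区域(窗口)级别,为每个查询区域筛选出最相关的top-k个区域;再在这些被选中区域的token上进行细粒度的token级注意力。这样模型能够更好地捕捉图像中的重要特征,同时避免对所有token两两计算注意力,从而降低计算复杂度。BiFormer为未来的计算机视觉任务提供了一种强大且高效的解决方案,适用于各种图像处理应用场景。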
优点:
- 提高准确性:通过双层路由注意力机制,BiFormer能够更准确地捕捉到图像中的全局和局部特征,提高模型的整体性能。这种方法在多个基准数据集上取得了优异的性能,表明其在实际应用中的有效性。
- 减少计算复杂度:BiFormer通过分层计算注意力,显著减少了计算复杂度,使其在处理高分辨率图像时更加高效。这种机制确保了模型的高效运行,适用于资源受限的环境。
- 易于集成和扩展:BiFormer易于与现有的视觉Transformer模型集成,无需对现有架构进行大幅度修改。这种灵活性使得BiFormer能够广泛应用于不同的图像处理任务。
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
class TopkRouting(nn.Module):
    """Top-k region routing (BiFormer).

    Every query region is scored against every key region with a scaled dot
    product. The ``topk`` highest-scoring key regions are kept per query
    region, and softmax-normalised routing weights over those k are returned.

    Args:
        qk_dim: feature dimension of query and key.
        topk: key regions kept per query region; clamped to the number of
            available regions.
        qk_scale: multiplier applied to the query before the dot product;
            falls back to ``qk_dim ** -0.5`` when None (or 0).
        param_routing: pass query/key through an extra learnable Linear.
        diff_routing: when False, query and key are detached so no gradient
            flows through the routing.
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            query, key: (n, p^2, c) tensors of per-region features.
        Returns:
            r_weight, topk_index: (n, p^2, k) tensors, k = min(topk, p^2).
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        # torch.topk raises when k exceeds the dimension size, e.g. when there
        # are fewer regions than ``topk``; never ask for more than exist.
        k = min(self.topk, attn_logit.size(-1))
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=k, dim=-1)  # (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)
        return r_weight, topk_index
class KVGather(nn.Module):
    """Collect, for every query region, the key/value tokens of its routed regions.

    Args:
        mul_weight: 'none' returns the gathered tokens unchanged, 'soft'
            multiplies them by the routing weights, and 'hard' raises
            NotImplementedError in ``forward``.
    """

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        Args:
            r_idx: (n, p^2, topk) indices of the routed regions.
            r_weight: (n, p^2, topk) routing weights.
            kv: (n, p^2, w^2, c_kq+c_v) per-region key/value tokens.
        Returns:
            (n, p^2, topk, w^2, c_kq+c_v) tensor.
        """
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # Broadcast kv along a new query-region axis without copying, then pick
        # the routed regions along the key-region axis.
        # FIXME (upstream): gather materialises topk copies of the tokens.
        source = kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1)  # (n, p^2, p^2, w^2, c_kv)
        index = r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)  # (n, p^2, k, w^2, c_kv)
        selected = torch.gather(source, dim=2, index=index)
        if self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        if self.mul_weight == 'soft':
            selected = r_weight.view(n, p2, topk, 1, 1) * selected
        return selected
class QKVLinear(nn.Module):
    """Fused projection producing queries and concatenated keys/values.

    A single Linear maps ``dim`` to ``qk_dim + qk_dim + dim`` features. The
    output is split into ``q`` (first ``qk_dim`` features) and ``kv`` (the
    remaining ``qk_dim + dim`` features: keys first, then values).
    """

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, 2 * qk_dim + dim, bias=bias)

    def forward(self, x):
        projected = self.qkv(x)
        q = projected[..., :self.qk_dim]
        kv = projected[..., self.qk_dim:]
        return q, kv
class BiLevelRoutingAttention(nn.Module):
    """Bi-level routing attention (BiFormer), adapted to NCHW feature maps.

    The (optionally padded) feature map is split into ``n_win x n_win``
    windows. Window-averaged queries and keys are routed by ``TopkRouting``
    to the ``topk`` most relevant windows. The (optionally downsampled)
    key/value tokens of those windows are gathered, and every query token of
    the window attends to them with multi-head attention. A depthwise-conv
    term (``lepe``) computed on the values is added before the output
    projection.

    Args:
        dim: input channels.
        n_win: windows per side (n_win * n_win windows in total).
        num_heads: attention heads; must divide ``dim`` and ``qk_dim``.
        qk_dim: query/key width (defaults to ``dim``).
        qk_scale: attention scale (defaults to ``qk_dim ** -0.5``).
        kv_per_win: for kv_downsample_mode='ada_xxxpool' only, key/values per
            window side (kv_per_win * kv_per_win per window).
        kv_downsample_ratio: pooling factor for 'maxpool' / 'avgpool'.
        kv_downsample_kernel: stored only; not used.
        kv_downsample_mode: 'identity', 'ada_avgpool', 'ada_maxpool',
            'maxpool' or 'avgpool'.
        topk: number of windows each window attends to.
        param_attention: 'qkvo' (with output Linear) or 'qkv' (identity output).
        param_routing: extra Linear inside the router (requires diff_routing).
        diff_routing: let gradients flow through routing. Without
            soft_routing this selects 'hard' gathering, which KVGather does
            not implement.
        soft_routing: multiply gathered key/values by the routing weights.
        side_dwconv: kernel size of the depthwise lepe conv; 0 disables it.
        auto_pad: zero-pad H/W up to multiples of ``n_win``, crop afterwards.
    """

    def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=True):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        # side_dwconv (i.e. LCE in ShuntedTransformer)
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        # global routing setting
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)  # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            # (the attribute name was misspelled here, which turned this
            # ValueError into an AttributeError)
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)
        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        Args:
            x: (N, C, H, W) tensor.
            ret_attn_mask: also return the routing weights/indices and the
                attention weights.
        Returns:
            (N, C, H, W) tensor; when ``ret_attn_mask`` is True, the tuple
            ``(out, r_weight, r_idx, attn_weight)`` with ``out`` in the same
            NCHW layout.
        """
        x = rearrange(x, "n c h w -> n h w c")
        # NOTE: use padding for semantic segmentation
        if self.auto_pad:
            N, H_in, W_in, C = x.size()
            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0

        # patchify: (n, p^2, w, w, c); keep 2-D windows for the kv pooling below
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        # q: (n, p^2, w, w, c_qk); kv: (n, p^2, w, w, c_qk + c_v)
        # NOTE: separate kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x)

        # pixel-wise q and (optionally downsampled) kv
        # q_pix: (n, p^2, w^2, c_qk); kv_pix: (n, p^2, h_kv*w_kv, c_qk + c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        # window-wise q and k: (n, p^2, c_qk) each
        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])

        # lepe on the values; contiguous() avoids a gradient warning under DDP
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        # route each window to its top-k windows and gather their k/v tokens
        r_weight, r_idx = self.router(q_win, k_win)  # (n, p^2, topk) each
        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)

        # multi-head attention over the gathered tokens
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)  # (n*p^2, m, c_qk//m, topk*h_kv*w_kv)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)  # (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)  # (n*p^2, m, w^2, c_qk//m)
        attn_weight = (q_pix * self.scale) @ k_pix_sel  # (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel  # (n*p^2, m, w^2, c_v//m)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # crop the padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        # Return to NCHW for both return paths, matching the input layout.
        out = rearrange(out, "n h w c -> n c h w")
        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        return out
class Attention(nn.Module):
    """Plain multi-head self-attention over all H*W tokens of an NCHW map.

    Args:
        dim: channels (token width).
        num_heads: number of heads; each head has ``dim // num_heads`` features.
        qkv_bias: add a bias to the fused qkv projection.
        qk_scale: attention scale; defaults to ``head_dim ** -0.5``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE (upstream): scale can be set manually to stay compatible with
        # weights trained with an earlier, different default.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over all positions of ``x`` (NCHW); returns an NCHW tensor."""
        _, _, height, width = x.size()
        tokens = rearrange(x, 'n c h w -> n (h w) c')
        batch, length, channels = tokens.shape
        head_width = channels // self.num_heads
        qkv = self.qkv(tokens).reshape(batch, length, 3, self.num_heads, head_width)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_width)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        mixed = (weights @ v).transpose(1, 2).reshape(batch, length, channels)
        mixed = self.proj_drop(self.proj(mixed))
        return rearrange(mixed, 'n (h w) c -> n c h w', h=height, w=width)
class AttentionLePE(nn.Module):
    """Multi-head self-attention over all H*W tokens plus a LePE term.

    This is plain global attention, except that a depthwise convolution
    (kernel ``side_dwconv``) of the input feature map is added to the
    attention output before the output projection. ``side_dwconv=0`` replaces
    that term with zeros.

    Args:
        dim: channels (token width).
        num_heads: number of heads; each head has ``dim // num_heads`` features.
        qkv_bias: add a bias to the fused qkv projection.
        qk_scale: attention scale; defaults to ``head_dim ** -0.5``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
        side_dwconv: kernel size of the depthwise LePE conv.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE (upstream): scale can be set manually to stay compatible with
        # weights trained with an earlier, different default.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if side_dwconv > 0:
            self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, groups=dim)
        else:
            self.lepe = lambda x: torch.zeros_like(x)

    def forward(self, x):
        """Attend over all positions of ``x`` (NCHW); returns an NCHW tensor."""
        _, _, height, width = x.size()
        tokens = rearrange(x, 'n c h w -> n (h w) c')
        batch, length, channels = tokens.shape
        qkv = self.qkv(tokens).reshape(batch, length, 3, self.num_heads, channels // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_width)
        # positional term from the input tokens laid back out as a feature map
        positional = self.lepe(rearrange(tokens, 'n (h w) c -> n c h w', h=height, w=width))
        positional = rearrange(positional, 'n c h w -> n (h w) c')
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        mixed = (weights @ v).transpose(1, 2).reshape(batch, length, channels)
        mixed = self.proj_drop(self.proj(mixed + positional))
        return rearrange(mixed, 'n (h w) c -> n c h w', h=height, w=width)
注:注意力机制的添加方式见文章,添加位置不固定,可根据自己的需求灵活调整。感谢大家的支持和关注❤❤
本文至此结束,文章持续更新中,敬请期待!!!
喜欢的本文的话,请不吝点赞+收藏,感谢大家的支持🍵🍵