Following up on the previous post, this article continues to show how to add advanced attention mechanisms to YOLOv10. More than 30 attention modules are covered, helping you squeeze extra accuracy points out of your model 🚀🚀
Tip: if you enjoy this column, please like, follow, and support it. This article is for learning and exchange only. Creating it took real effort, so please do not repost or reproduce it without the author's permission!!!
Introduction to the Attention Modules🛩️🛩️
26. Axial Attention🌱🌱
Axial Attention is an efficient attention mechanism designed for multi-dimensional data such as images and videos. It decomposes the high-dimensional attention operation into separate attention passes along each axis, which lowers the computational complexity and improves model efficiency.
The core idea of Axial Attention is to factorize multi-dimensional attention into a sequence of one-dimensional attention operations. For a 2D feature map such as an image, attention is computed along the horizontal axis and the vertical axis separately, reducing the cost from O((HW)²) for full self-attention to roughly O(HW·(H+W)) while still capturing global context along each axis.
import torch
from torch import nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
def forward(self, x, f_args={}, g_args={}):
x1, x2 = torch.chunk(x, 2, dim=1)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
return torch.cat([y1, y2], dim=1)
def backward_pass(self, y, dy, f_args={}, g_args={}):
y1, y2 = torch.chunk(y, 2, dim=1)
del y
dy1, dy2 = torch.chunk(dy, 2, dim=1)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng=True, **f_args)
torch.autograd.backward(fx2, dx1, retain_graph=True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim=1)
dx = torch.cat([dx1, dx2], dim=1)
return x, dx
class IrreversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = f
self.g = g
def forward(self, x, f_args, g_args):
x1, x2 = torch.chunk(x, 2, dim=1)
y1 = x1 + self.f(x2, **f_args)
y2 = x2 + self.g(y1, **g_args)
return torch.cat([y1, y2], dim=1)
class _ReversibleFunction(Function):
@staticmethod
def forward(ctx, x, blocks, kwargs):
ctx.kwargs = kwargs
for block in blocks:
x = block(x, **kwargs)
ctx.y = x.detach()
ctx.blocks = blocks
return x
@staticmethod
def backward(ctx, dy):
y = ctx.y
kwargs = ctx.kwargs
for block in ctx.blocks[::-1]:
y, dy = block.backward_pass(y, dy, **kwargs)
return dy, None, None
class ReversibleSequence(nn.Module):
def __init__(
self,
blocks,
):
super().__init__()
self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])
def forward(self, x, arg_route=(True, True), **kwargs):
f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
block_kwargs = {"f_args": f_args, "g_args": g_args}
x = torch.cat((x, x), dim=1)
x = _ReversibleFunction.apply(x, self.blocks, block_kwargs)
return torch.stack(x.chunk(2, dim=1)).mean(dim=0)
# helper functions
def exists(val):
return val is not None
def map_el_ind(arr, ind):
return list(map(itemgetter(ind), arr))
def sort_and_return_indices(arr):
indices = [ind for ind in range(len(arr))]
arr = zip(arr, indices)
arr = sorted(arr)
return map_el_ind(arr, 0), map_el_ind(arr, 1)
# calculates the permutation to bring the input tensor to something attend-able
# also calculates the inverse permutation to bring the tensor back to its original shape
def calculate_permutations(num_dimensions, emb_dim):
total_dimensions = num_dimensions + 2
emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]
permutations = []
for axial_dim in axial_dims:
last_two_dims = [axial_dim, emb_dim]
dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
permutation = [*dims_rest, *last_two_dims]
permutations.append(permutation)
return permutations
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
class Sequential(nn.Module):
def __init__(self, blocks):
super().__init__()
self.blocks = blocks
def forward(self, x):
for f, g in self.blocks:
x = x + f(x)
x = x + g(x)
return x
class PermuteToFrom(nn.Module):
def __init__(self, permutation, fn):
super().__init__()
self.fn = fn
_, inv_permutation = sort_and_return_indices(permutation)
self.permutation = permutation
self.inv_permutation = inv_permutation
def forward(self, x, **kwargs):
axial = x.permute(*self.permutation).contiguous()
shape = axial.shape
*_, t, d = shape
# merge all but axial dimension
axial = axial.reshape(-1, t, d)
# attention
axial = self.fn(axial, **kwargs)
# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()
return axial
# axial pos emb
class AxialPositionalEmbedding(nn.Module):
def __init__(self, dim, shape, emb_dim_index=1):
super().__init__()
parameters = []
total_dimensions = len(shape) + 2
ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
self.num_axials = len(shape)
for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
shape = [1] * total_dimensions
shape[emb_dim_index] = dim
shape[axial_dim_index] = axial_dim
parameter = nn.Parameter(torch.randn(*shape))
setattr(self, f"param_{i}", parameter)
def forward(self, x):
for i in range(self.num_axials):
x = x + getattr(self, f"param_{i}")
return x
# attention
class SelfAttention(nn.Module):
def __init__(self, dim, heads, dim_heads=None):
super().__init__()
self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
dim_hidden = self.dim_heads * heads
self.heads = heads
self.to_q = nn.Linear(dim, dim_hidden, bias=False)
self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x, kv=None):
kv = x if kv is None else kv
q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))
b, t, d, h, e = *q.shape, self.heads, self.dim_heads
merge_heads = (
lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
)
q, k, v = map(merge_heads, (q, k, v))
dots = torch.einsum("bie,bje->bij", q, k) * (e**-0.5)
dots = dots.softmax(dim=-1)
out = torch.einsum("bij,bje->bie", dots, v)
out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
out = self.to_out(out)
return out
# axial attention class
class AxialAttention(nn.Module):
def __init__(
self,
dim,
num_dimensions=2,
heads=8,
dim_heads=None,
dim_index=-1,
sum_axial_out=True,
):
assert (
dim % heads
) == 0, "hidden dimension must be divisible by number of heads"
super().__init__()
self.dim = dim
self.total_dimensions = num_dimensions + 2
self.dim_index = (
dim_index if dim_index > 0 else (dim_index + self.total_dimensions)
)
attentions = []
for permutation in calculate_permutations(num_dimensions, dim_index):
attentions.append(
PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads))
)
self.axial_attentions = nn.ModuleList(attentions)
self.sum_axial_out = sum_axial_out
def forward(self, x):
assert (
len(x.shape) == self.total_dimensions
), "input tensor does not have the correct number of dimensions"
assert (
x.shape[self.dim_index] == self.dim
), "input tensor does not have the correct input dimension"
if self.sum_axial_out:
return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))
out = x
for axial_attn in self.axial_attentions:
out = axial_attn(out)
return out
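Here is a minimal usage sketch for the AxialAttention module defined above; the tensor shape and hyperparameters (dim=64, a 32×32 feature map) are illustrative assumptions, not values from the original post:
# quick sanity check: axial attention over a channels-first 2D feature map
x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
attn = AxialAttention(dim=64, num_dimensions=2, heads=8, dim_index=1, sum_axial_out=True)
out = attn(x)  # attention along H and along W, summed
print(out.shape)  # torch.Size([1, 64, 32, 32])
With sum_axial_out=True the height-axis and width-axis attention outputs are summed; set it to False to apply them sequentially instead.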
27. HaloAttention🌱🌱
Halo Attention is an efficient self-attention mechanism proposed for image and vision tasks that aims to improve parameter and compute efficiency. It uses local (blocked) self-attention to cut computational cost while still preserving broader contextual information.
The core idea of Halo Attention is to partition the input feature map into local blocks and compute self-attention within each block. Each block's keys and values are drawn from a slightly larger neighborhood that includes a surrounding halo region, so local features are captured while neighboring context is retained.
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
def to(x):
return {"device": x.device, "dtype": x.dtype}
def pair(x):
return (x, x) if not isinstance(x, tuple) else x
def expand_dim(t, dim, k):
t = t.unsqueeze(dim=dim)
expand_shape = [-1] * len(t.shape)
expand_shape[dim] = k
return t.expand(*expand_shape)
def rel_to_abs(x):
b, l, m = x.shape
r = (m + 1) // 2
col_pad = torch.zeros((b, l, 1), **to(x))
x = torch.cat((x, col_pad), dim=2)
flat_x = rearrange(x, "b l c -> b (l c)")
flat_pad = torch.zeros((b, m - l), **to(x))
flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
final_x = flat_x_padded.reshape(b, l + 1, m)
final_x = final_x[:, :l, -r:]
return final_x
def relative_logits_1d(q, rel_k):
b, h, w, _ = q.shape
r = (rel_k.shape[0] + 1) // 2
logits = einsum("b x y d, r d -> b x y r", q, rel_k)
logits = rearrange(logits, "b x y r -> (b x) y r")
logits = rel_to_abs(logits)
logits = logits.reshape(b, h, w, r)
logits = expand_dim(logits, dim=2, k=r)
return logits
class RelPosEmb(nn.Module):
def __init__(self, block_size, rel_size, dim_head):
super().__init__()
height = width = rel_size
scale = dim_head**-0.5
self.block_size = block_size
self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)
def forward(self, q):
block = self.block_size
q = rearrange(q, "b (x y) c -> b x y c", x=block)
rel_logits_w = relative_logits_1d(q, self.rel_width)
rel_logits_w = rearrange(rel_logits_w, "b x i y j-> b (x y) (i j)")
q = rearrange(q, "b x y d -> b y x d")
rel_logits_h = relative_logits_1d(q, self.rel_height)
rel_logits_h = rearrange(rel_logits_h, "b x i y j -> b (y x) (j i)")
return rel_logits_w + rel_logits_h
class HaloAttention(nn.Module):
def __init__(self, dim, block_size, halo_size, dim_head=64, heads=8):
super().__init__()
assert halo_size > 0, "halo size must be greater than 0"
self.dim = dim
self.heads = heads
self.scale = dim_head**-0.5
self.block_size = block_size
self.halo_size = halo_size
inner_dim = dim_head * heads
self.rel_pos_emb = RelPosEmb(
block_size=block_size,
rel_size=block_size + (halo_size * 2),
dim_head=dim_head,
)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
b, c, h, w, block, halo, heads, device = (
*x.shape,
self.block_size,
self.halo_size,
self.heads,
x.device,
)
assert (
h % block == 0 and w % block == 0
), "fmap dimensions must be divisible by the block size"
assert (
c == self.dim
), f"channels for input ({c}) does not equal to the correct dimension ({self.dim})"
# get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values
q_inp = rearrange(
x, "b c (h p1) (w p2) -> (b h w) (p1 p2) c", p1=block, p2=block
)
kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)
kv_inp = rearrange(kv_inp, "b (c j) i -> (b i) j c", c=c)
# derive queries, keys, values
q = self.to_q(q_inp)
k, v = self.to_kv(kv_inp).chunk(2, dim=-1)
# split heads
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v)
)
# scale
q *= self.scale
# attention
sim = einsum("b i d, b j d -> b i j", q, k)
# add relative positional bias
sim += self.rel_pos_emb(q)
# mask out padding (in the paper, they claim to not need masks, but what about padding?)
mask = torch.ones(1, 1, h, w, device=device)
mask = F.unfold(
mask, kernel_size=block + (halo * 2), stride=block, padding=halo
)
mask = repeat(mask, "() j i -> (b i h) () j", b=b, h=heads)
mask = mask.bool()
max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(~mask, max_neg_value)  # mask out the zero-padded (out-of-image) key positions
# attention
attn = sim.softmax(dim=-1)
# aggregate
out = einsum("b i j, b j d -> b i d", attn, v)
# merge and combine heads
out = rearrange(out, "(b h) n d -> b n (h d)", h=heads)
out = self.to_out(out)
# merge blocks back to original feature map
out = rearrange(
out,
"(b h w) (p1 p2) c -> b c (h p1) (w p2)",
b=b,
h=(h // block),
w=(w // block),
p1=block,
p2=block,
)
return out
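A minimal usage sketch for the HaloAttention module above; the sizes below (dim=32, a 64×64 map, block_size=8, halo_size=4) are illustrative assumptions, and the spatial dimensions must be divisible by block_size:
# quick sanity check: 8x8 query blocks attending to 16x16 halo-ed key/value neighborhoods
x = torch.randn(1, 32, 64, 64)  # (batch, channels, height, width)
attn = HaloAttention(dim=32, block_size=8, halo_size=4, dim_head=16, heads=4)
out = attn(x)
print(out.shape)  # torch.Size([1, 32, 64, 64])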
28. iRMB (Inverted Residual Mobile Block)🌱🌱
iRMB is designed to improve the efficiency of attention-based models, particularly under the constraints of mobile devices. It combines the inverted residual structure with a lightweight attention mechanism to optimize compute efficiency and parameter utilization.
The core of the iRMB design is the inverted residual structure, which expands features from a low-dimensional space to a higher-dimensional one and then projects them back, keeping most of the computation in cheap depth-wise operations. iRMB also integrates channel attention and window-based spatial attention to sharpen the model's focus on important features and strengthen its representations.
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.efficientnet_blocks import SqueezeExcite as SE
from einops import rearrange, reduce
from timm.models.layers.activations import *
from timm.models.layers import DropPath
inplace = True
# SE: a simple squeeze-and-excitation variant (kept for reference; iRMB below uses timm's SqueezeExcite, imported above as SE)
class SELayer(nn.Module):
    def __init__(self, c1, ratio=16):
        super(SELayer, self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
self.relu = nn.ReLU(inplace=True)
self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
self.sig = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avgpool(x).view(b, c)
y = self.l1(y)
y = self.relu(y)
y = self.l2(y)
y = self.sig(y)
y = y.view(b, c, 1, 1)
return x * y.expand_as(x)
class LayerNorm2d(nn.Module):
def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
super().__init__()
self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)
def forward(self, x):
x = rearrange(x, 'b c h w -> b h w c').contiguous()
x = self.norm(x)
x = rearrange(x, 'b h w c -> b c h w').contiguous()
return x
def get_norm(norm_layer='in_1d'):
eps = 1e-6
norm_dict = {
'none': nn.Identity,
'in_1d': partial(nn.InstanceNorm1d, eps=eps),
'in_2d': partial(nn.InstanceNorm2d, eps=eps),
'in_3d': partial(nn.InstanceNorm3d, eps=eps),
'bn_1d': partial(nn.BatchNorm1d, eps=eps),
'bn_2d': partial(nn.BatchNorm2d, eps=eps),
# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
'bn_3d': partial(nn.BatchNorm3d, eps=eps),
'gn': partial(nn.GroupNorm, eps=eps),
'ln_1d': partial(nn.LayerNorm, eps=eps),
'ln_2d': partial(LayerNorm2d, eps=eps),
}
return norm_dict[norm_layer]
def get_act(act_layer='relu'):
act_dict = {
'none': nn.Identity,
'sigmoid': Sigmoid,
'swish': Swish,
'mish': Mish,
'hsigmoid': HardSigmoid,
'hswish': HardSwish,
'hmish': HardMish,
'tanh': Tanh,
'relu': nn.ReLU,
'relu6': nn.ReLU6,
'prelu': PReLU,
'gelu': GELU,
'silu': nn.SiLU
}
return act_dict[act_layer]
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=True):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(1, 1, dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class LayerScale2D(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=True):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(1, dim, 1, 1))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class ConvNormAct(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
super(ConvNormAct, self).__init__()
self.has_skip = skip and dim_in == dim_out
padding = math.ceil((kernel_size - stride) / 2)
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
self.norm = get_norm(norm_layer)(dim_out)
self.act = get_act(act_layer)(inplace=inplace)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
# ========== Multi-Scale Populations, for down-sampling and inductive bias ==========
class MSPatchEmb(nn.Module):
def __init__(self, dim_in, emb_dim, kernel_size=2, c_group=-1, stride=1, dilations=[1, 2, 3],
norm_layer='bn_2d', act_layer='silu'):
super().__init__()
self.dilation_num = len(dilations)
assert dim_in % c_group == 0
c_group = math.gcd(dim_in, emb_dim) if c_group == -1 else c_group
self.convs = nn.ModuleList()
for i in range(len(dilations)):
padding = math.ceil(((kernel_size - 1) * dilations[i] + 1 - stride) / 2)
self.convs.append(nn.Sequential(
nn.Conv2d(dim_in, emb_dim, kernel_size, stride, padding, dilations[i], groups=c_group),
get_norm(norm_layer)(emb_dim),
get_act(act_layer)(emb_dim)))
def forward(self, x):
if self.dilation_num == 1:
x = self.convs[0](x)
else:
x = torch.cat([self.convs[i](x).unsqueeze(dim=-1) for i in range(self.dilation_num)], dim=-1)
x = reduce(x, 'b c h w n -> b c h w', 'mean').contiguous()
return x
class iRMB(nn.Module):
def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
act_layer='relu', v_proj=True, dw_ks=3, stride=1, dilation=1, se_ratio=0.0, dim_head=64, window_size=7,
attn_s=True, qkv_bias=False, attn_drop=0., drop=0., drop_path=0., v_group=False, attn_pre=False):
super().__init__()
self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
dim_mid = int(dim_in * exp_ratio)
self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
self.attn_s = attn_s
if self.attn_s:
assert dim_in % dim_head == 0, 'dim should be divisible by num_heads'
self.dim_head = dim_head
self.window_size = window_size
self.num_head = dim_in // dim_head
self.scale = self.dim_head ** -0.5
self.attn_pre = attn_pre
self.qk = ConvNormAct(dim_in, int(dim_in * 2), kernel_size=1, bias=qkv_bias, norm_layer='none',
act_layer='none')
self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, groups=self.num_head if v_group else 1, bias=qkv_bias,
norm_layer='none', act_layer=act_layer, inplace=inplace)
self.attn_drop = nn.Dropout(attn_drop)
else:
if v_proj:
self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, bias=qkv_bias, norm_layer='none',
act_layer=act_layer, inplace=inplace)
else:
self.v = nn.Identity()
self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation,
groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)
self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
self.proj_drop = nn.Dropout(drop)
self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
def forward(self, x):
shortcut = x
x = self.norm(x)
B, C, H, W = x.shape
if self.attn_s:
# padding
if self.window_size <= 0:
window_size_W, window_size_H = W, H
else:
window_size_W, window_size_H = self.window_size, self.window_size
pad_l, pad_t = 0, 0
pad_r = (window_size_W - W % window_size_W) % window_size_W
pad_b = (window_size_H - H % window_size_H) % window_size_H
x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0,))
n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous()
# attention
b, c, h, w = x.shape
qk = self.qk(x)
qk = rearrange(qk, 'b (qk heads dim_head) h w -> qk b heads (h w) dim_head', qk=2, heads=self.num_head,
dim_head=self.dim_head).contiguous()
q, k = qk[0], qk[1]
attn_spa = (q @ k.transpose(-2, -1)) * self.scale
attn_spa = attn_spa.softmax(dim=-1)
attn_spa = self.attn_drop(attn_spa)
if self.attn_pre:
x = rearrange(x, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
x_spa = attn_spa @ x
x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head, h=h,
w=w).contiguous()
x_spa = self.v(x_spa)
else:
v = self.v(x)
v = rearrange(v, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
x_spa = attn_spa @ v
x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head, h=h,
w=w).contiguous()
# unpadding
x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous()
if pad_r > 0 or pad_b > 0:
x = x[:, :, :H, :W].contiguous()
else:
x = self.v(x)
x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))
x = self.proj_drop(x)
x = self.proj(x)
x = (shortcut + self.drop_path(x)) if self.has_skip else x
return x
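A minimal usage sketch for the iRMB block above; the channel count, expansion ratio, and window size are illustrative assumptions (when attn_s=True, dim_in must be divisible by dim_head):
# quick sanity check: iRMB with window attention and a residual skip (dim_in == dim_out, stride == 1)
x = torch.randn(1, 64, 56, 56)  # (batch, channels, height, width)
block = iRMB(dim_in=64, dim_out=64, exp_ratio=2.0, dim_head=16, window_size=7, attn_s=True)
out = block(x)
print(out.shape)  # torch.Size([1, 64, 56, 56])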
29. NAM Attention🌱🌱
The NAM (Normalization-based Attention Module) reweights the input feature map along the channel and spatial dimensions to strengthen the model's feature representation. It derives the attention weights from normalization statistics (the scaling factors of batch normalization), so the model concentrates on the more informative features and performs better on the task. The snippet below implements the channel-attention branch.
import torch.nn as nn
import torch
from torch.nn import functional as F
class Channel_Att(nn.Module):
def __init__(self, channels):
super(Channel_Att, self).__init__()
self.channels = channels
self.bn2 = nn.BatchNorm2d(self.channels, affine=True)
def forward(self, x):
residual = x
x = self.bn2(x)
weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())
x = x.permute(0, 2, 3, 1).contiguous()
x = torch.mul(weight_bn, x)
x = x.permute(0, 3, 1, 2).contiguous()
        x = torch.sigmoid(x) * residual
return x
class NAMAttention(nn.Module):
def __init__(self, channels):
super(NAMAttention, self).__init__()
self.Channel_Att = Channel_Att(channels)
def forward(self, x):
x_out1 = self.Channel_Att(x)
return x_out1
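A minimal usage sketch for the NAMAttention module above; the channel count and feature-map size are illustrative assumptions:
# quick sanity check: channel reweighting driven by the BatchNorm scaling factors
x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
nam = NAMAttention(channels=64)
out = nam(x)
print(out.shape)  # torch.Size([1, 64, 32, 32])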
30. DAN Channel Attention🌱🌱
The channel attention module (CAM) from DANet (Dual Attention Network) explicitly models the inter-dependencies between channels. By exploiting the inter-dependencies among channel maps, it emphasizes interdependent feature maps and improves the feature representation of specific semantics. Focusing on the important feature channels helps improve scene-segmentation accuracy.
import torch
import torch.nn as nn

class CAM_Module(nn.Module):
"""Channel attention module"""
def __init__(self, in_dim):
super(CAM_Module, self).__init__()
self.channel_in = in_dim
        # unlike PAM, CAM has no Conv2d projection layers
        self.gamma = nn.Parameter(torch.zeros(1))  # note how the scale factor (beta in the paper) is initialized to zero and learned
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs:
x : input feature maps
returns:
out:attention value + input feature
attention: B * C * C
"""
m_batchsize, C, height, width = x.size()
proj_query = x.view(m_batchsize, C, -1)
proj_key = x.view(m_batchsize,C,-1).permute(0,2,1)
energy = torch.bmm(proj_query, proj_key)
energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
attention = self.softmax(energy_new)
proj_value = x.view(m_batchsize, C, -1)
out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)
out = self.gamma*out + x
return out
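A minimal usage sketch for CAM_Module; the sizes are illustrative assumptions. Note that the attention map is C × C, so the cost grows with the number of channels rather than with spatial resolution:
# quick sanity check: DANet channel attention
x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
cam = CAM_Module(in_dim=64)
out = cam(x)  # gamma-scaled channel-attention output added back to the input
print(out.shape)  # torch.Size([1, 64, 32, 32])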
31. DAN Positional Attention🌱🌱
The position attention module (PAM) from DANet helps the model build rich contextual relationships over local features: it encodes broader contextual information into local features, thereby enhancing their representational power.
class PAM_Module(nn.Module):
"""Position attention module"""
# Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.channel_in = in_dim
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8,kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8,kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim,kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # note how the scale factor (alpha in the paper) is initialized to zero and learned
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs:
x : input feature maps
returns:
out:attention value + input feature
attention: B * (H*W) * (H*W)
"""
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0,2,1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
energy = torch.bmm(proj_query,proj_key)
        attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
out = torch.bmm(proj_value,attention.permute(0,2,1))
out = out.view(m_batchsize, C, height, width)
out = self.gamma*out + x
return out
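A minimal usage sketch for PAM_Module; the sizes are illustrative assumptions. The attention map here is (H·W) × (H·W), so it is best applied to relatively small feature maps:
# quick sanity check: DANet position attention
x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
pam = PAM_Module(in_dim=64)
out = pam(x)  # gamma-scaled position-attention output added back to the input
print(out.shape)  # torch.Size([1, 64, 32, 32])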
Note: see the article for how to add these attention mechanisms; the insertion position is not fixed and can be adjusted flexibly according to your own needs. Thanks to everyone for your support and attention ❤❤
This concludes the article. The series is being updated continuously, so stay tuned!!!
If you like this article, please like and bookmark it. Thanks for your support 🍵🍵