1. Introduction
Transformer models have several properties that make them well suited to building powerful data-driven models, such as the ability to capture long-range dependencies in data. Unlike the local nature of convolution, the global receptive field of the attention mechanism lets vision Transformers model long-range dependencies, and introducing sparse attention reduces the computational burden. Because queries in different semantic regions attend to different key-value pairs, forcing all queries to attend to the same set of tokens is suboptimal. One line of improvement predicts attention offsets from each query's local context to reduce computational complexity. To locate valuable key-value pairs more efficiently, a region-to-region routing strategy can be used, which filters out irrelevant key-value pairs at the coarse region level rather than at the fine-grained token level. Using BRA (Bi-Level Routing Attention) as the core building block, the authors propose BiFormer, a general-purpose vision Transformer backbone. BRA lets BiFormer attend to the most relevant key/value tokens for each query in a content-aware manner, achieving a better trade-off between performance and computational efficiency.
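To make the region-to-region routing idea concrete, here is a minimal, self-contained sketch (an illustration only, not the actual BiFormer implementation; the function name, the mean-pooled region descriptors, and the tensor layout are assumptions made for this example):

import torch

def bi_level_routing_attention_sketch(q, k, v, topk=4):
    # q, k, v: (batch, n_regions, tokens_per_region, dim)
    b, r, t, c = q.shape
    # 1) coarse routing: compare mean-pooled region-level queries and keys
    scores = q.mean(dim=2) @ k.mean(dim=2).transpose(-2, -1)      # (b, r, r)
    topk_idx = scores.topk(topk, dim=-1).indices                  # (b, r, topk)
    # 2) gather key/value tokens only from the top-k routed regions
    idx = topk_idx[..., None, None].expand(-1, -1, -1, t, c)      # (b, r, topk, t, c)
    k_sel = torch.gather(k.unsqueeze(1).expand(-1, r, -1, -1, -1), 2, idx).flatten(2, 3)
    v_sel = torch.gather(v.unsqueeze(1).expand(-1, r, -1, -1, -1), 2, idx).flatten(2, 3)
    # 3) fine-grained token-to-token attention restricted to the gathered tokens
    attn = torch.softmax((q @ k_sel.transpose(-2, -1)) * c ** -0.5, dim=-1)
    return attn @ v_sel                                           # (b, r, t, c)

The actual BRA module pasted below follows the same three steps, but with learnable projections (QKVLinear), an optional routing embedding (TopkRouting), and a memory-aware gather (KVGather).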
For a detailed description of BiFormer, see the paper: https://arxiv.org/pdf/2303.08810.pdf
This article explains how to integrate BiFormer into yolov8.
Without further ado, on to the code!
2. Integrating BiFormer into yolov8
2.1 Step 1
Go to the directory 'ultralytics/nn/modules' and create a file named BiFormer.py there (you can name the file however you like), then copy the core BiFormer code into it:
###################### BiLevelRoutingAttention #### AI&CV start ###############################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
class TopkRouting(nn.Module):
    """
    Differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        param_routing: bool, whether to incorporate learnable params in the routing unit
        diff_routing: bool, whether to make the routing differentiable
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tensor:
        """
        Args:
            query, key: (n, p^2, c) region-level tensors
        Return:
            r_weight, topk_index: (n, p^2, topk) tensors
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # per-window pooling -> (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)
        return r_weight, topk_index
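# QKVLinear is a single linear projection that produces the query (qk_dim channels)
# and the packed key/value features (qk_dim + dim channels) in one pass.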
class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv
        # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # return q, k, v
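# KVGather collects, for every query region, the key/value tokens of its top-k routed
# regions (r_idx from TopkRouting), optionally re-weighting them by the routing
# weights when mul_weight == 'soft'.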
class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)
        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),  # (n, p^2, p^2, w^2, c_kv) without mem cpy
                               dim=2,
                               index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)  # (n, p^2, k, w^2, c_kv)
                               )
        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv  # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # else: 'none', do nothing
        return topk_kv
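# BiLevelRoutingAttention (BRA) ties the pieces above together: the feature map is
# split into n_win x n_win regions, QKVLinear produces per-token q/kv, TopkRouting
# selects the top-k most relevant regions for each query region, KVGather collects
# their key/value tokens, and fine-grained token-to-token attention is applied on them.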
class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows i