一、本文介绍
作为入门性篇章,这里介绍了MHSA注意力在YOLOv8中的使用。包含MHSA的代码、MHSA的使用方法、以及添加以后的yaml文件及运行记录。
二、MHSA原理分析
MHSA官方论文地址:MHSA注意力机制文章
MHSA注意力机制(多头注意力机制):
相关代码:
MHSA注意力的代码,如下:
class MHSA(nn.Module):
def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
super(MHSA, self).__init__()
self.heads = heads
self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.pos = pos_emb
if self.pos:
self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]),
requires_grad=True)
self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]),