Preface
Last time I hand-rolled Self-Attention; this time let's build a most basic Multi-Head Self-Attention. For how Self-Attention itself is constructed, see my earlier post: 手搓Self-Attention.
Steps
Again following the same expert's blog post, let's analyze the figure. In the illustrated example with 8 heads, each head $i$ has its own projection matrices:

$Q_i = X W_i^Q,\quad K_i = X W_i^K,\quad V_i = X W_i^V,$

$Z_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i,$

that is,

$\mathrm{MultiHead}(X) = \mathrm{Concat}(Z_0, Z_1, \dots, Z_7)\, W^O.$

We already implemented the computation of each $Z_i$ last time, so all that is left is to concatenate the $Z_i$ along the feature dimension and apply a linear transformation $W^O$ to obtain the final output.
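As a quick shape check of this concatenate-then-project step, here is a minimal sketch; the batch size, sequence length and model dimension below are illustrative assumptions, only the head count of 8 comes from the figure.

import torch as pt

# illustrative sizes: batch 1, sequence length 3, model dimension 16, 8 heads
B, L, d_model, h = 1, 3, 16, 8
d_k = d_model // h

# pretend each head has already produced its Z_i of shape (B, L, d_k)
Zs = [pt.randn(B, L, d_k) for _ in range(h)]

Z = pt.cat(Zs, dim=-1)               # (B, L, h * d_k) == (B, L, d_model)
Wo = pt.nn.Linear(h * d_k, d_model)  # the final linear map W^O
print(Wo(Z).shape)                   # torch.Size([1, 3, 16])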
Implementation
1. Self-Attention
import torch as pt
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, d_m, d_k) -> None:
        super(SelfAttention, self).__init__()
        self.d_m = d_m  # input (model) dimension
        self.d_k = d_k  # query/key/value dimension of this head
        self.Wq = nn.Linear(in_features=self.d_m, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_m, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_m, out_features=self.d_k)

    def forward(self, x):
        # x: (batch, seq_len, d_m)
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        # scaled dot-product scores between every pair of positions: (batch, seq_len, seq_len)
        score = pt.einsum('nid, njd -> nij', q, k)
        score = score / self.d_k ** 0.5
        score = nn.functional.softmax(score, dim=-1)
        # weighted sum of the values: (batch, seq_len, d_k)
        out = pt.einsum('nij, njd -> nid', score, v)
        return out
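Before going multi-head, a quick usage sketch of the single head above; the sizes are illustrative assumptions.

''' example (single head) '''
x = pt.randn(2, 3, 10)  # (batch, seq_len, d_m)
head = SelfAttention(d_m=10, d_k=5)
z = head(x)
print(z.shape)          # torch.Size([2, 3, 5])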
2. Multi-Head Self-Attention
Here we follow nn.MultiheadAttention: the embedding dimension must be divisible by the number of heads (a shape-only cross-check against the built-in module follows the example below). In the illustration, though, the number of heads is 8 while the embedding dimension is drawn as 4, which seems odd; if anyone knows what is going on there, please leave a comment.
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads) -> None:
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
        # one single-head SelfAttention per head, each producing embed_dim // num_heads features
        self.heads = nn.ModuleList([SelfAttention(d_m=embed_dim, d_k=embed_dim // num_heads)
                                    for _ in range(num_heads)])
        # output projection W^O, applied to the concatenated heads
        self.Wo = nn.Linear(in_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        # concatenate every head's Z_i along the feature dimension, then project
        Z = pt.cat([head(x) for head in self.heads], dim=-1)  # (batch, seq_len, embed_dim)
        return self.Wo(Z)


''' example '''
a = pt.randn(2, 3, 10)             # (batch, seq_len, embed_dim)
m = MultiHeadSelfAttention(10, 5)  # 5 heads, d_k = 2 each
b = m(a)
print(b.shape)                     # torch.Size([2, 3, 10])
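Since we said we follow nn.MultiheadAttention, here is a small shape-only cross-check against the built-in module. This is a sketch, not an equivalence test: the two modules have different weights, so only the output shapes should match, and the batch_first argument requires a reasonably recent PyTorch.

''' cross-check against nn.MultiheadAttention (shapes only) '''
mha = nn.MultiheadAttention(embed_dim=10, num_heads=5, batch_first=True)
ref, _ = mha(a, a, a)      # self-attention: query = key = value = a
print(ref.shape, b.shape)  # both torch.Size([2, 3, 10])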
Results
On the left is the single-head result and on the right the two-head result; the two-head version does slightly better than the single head.
Closing
If you have any questions, feel free to discuss them in the comments!