讲解版:
# 自注意力机制函数attention 实现思路分析
# attention(query, key, value, mask=None, dropout=None)
# 1 求查询张量特征尺寸大小 d_k
# 2 求查询张量q的权重分布socres q@k^T /math.sqrt(d_k)
# 形状[2,4,512] @ [2,512,4] --->[2,4,4]
# 3 是否对权重分布scores进行 scores.masked_fill(mask == 0, -1e9)
# 4 求查询张量q的权重分布 p_attn F.softmax()
# 5 是否对p_attn进行dropout if dropout is not None:
# 6 求查询张量q的注意力结果表示 [2,4,4]@[2,4,512] --->[2,4,512]
# 7 返回q的注意力结果表示 q的权重分布
def attention(query, key, value, mask=None, dropout=None):
# query, key, value:代表注意力的三个输入张量
# mask:代表掩码张量
# dropout:传入的dropout实例化对象
# 1 求查询张量特征尺寸大小
d_k = query.size()[-1]
# 2 求查询张量q的权重分布socres q@k^T /math.sqrt(d_k)
# [2,4,512] @ [2,512,4] --->[2,4,4]
scores = torch.matmul(query, key.transpose(-2, -1) ) / math.sqrt(d_k)
# 3 是否对权重分布scores 进行 masked_fill
if mask is not None:
# 根据mask矩阵0的位置 对sorces矩阵对应位置进行掩码
scores = scores.masked_fill(mask == 0, -1e9)
# 4 求查询张量q的权重分布 softmax
p_attn = F.softmax(scores, dim=-1)
# 5 是否对p_attn进行dropout
if dropout is not None:
p_attn = dropout(p_attn)
# 返回 查询张量q的注意力结果表示 bmm-matmul运算, 注意力查询张量q的权重分布p_attn
# [2,4,4]*[2,4,512] --->[2,4,512]
return torch.matmul(p_attn, value), p_attn
# 多头注意力机制类 MultiHeadedAttention 实现思路分析
# 1 init函数 (self, head, e