今天看到一篇这个《再谈attention》机制,加深了对attention的理解。
了解查询query,键key,值value。
基本过程是query和键在注意力汇聚层中经过计算,得到了对应的注意力权重,再通过这个权重,对应键的值在输出就会占不同比例,从而实现对值的选择倾向。
根据注意力的计算过程可以分为:
- 加性注意力
- 缩放点积注意力 (transformer主要用到的)
注意力计算过程中,会首先根据键和查询进行注意力计算,得到对应注意力分数,理解为键和查询的相关度,越相关则对应分数越高。再经过softmax映射为0,1之间,得到注意力权重。权重和值相乘得到值得加权和。
查询、键、值都来自同一个地方,称为自注意力self-attention。
自注意力本质是计算序列中每个元素与序列中其他元素之间的相似性。从而可以融合全局信息。
import torch.nn as nn
class AdditiveAttention(nn.Module):
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
# 转为num_hiddens
self.w_k = nn.Linear(key_size, num_hiddens, bias=False)
self.w_q = nn.Linear(query_size, num_hiddens, bias=False)
# 计算值
self.w_v = nn.Linear(num_hiddens, 1, bias=False)
# 防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.w_q(queries), self.w_k(keys)
features = queries.unsqueeze(2) = keys.unsqueeze(1)
features = torch.tanh(features)
socres = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
7054

被折叠的 条评论
为什么被折叠?



