注意力机制 -自注意力和位置编码

本文探讨了自注意力在序列编码中的应用,比较了CNN、RNN与自注意力的优缺点,并介绍了如何通过位置编码(包括绝对和相对位置信息)增强模型对序列顺序的理解。重点讲解了MultiHeadAttention和PositionalEncoding的概念及其在深度学习实践中的实现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自注意力和位置编码

在深度学习中,我们经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。想象一下,有了注意力机制之后,我们将词元序列输入注意力池化后,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出

由于查询、键和值来自同一组输入,因此被称为自注意力(self-attention),也被称为内部注意力(intra-attention)。在本节中,我们将使用自注意力进行序列编码,以及如何使用序列的顺序作为补充信息

import math
import torch
from torch import nn
from d2l import torch as d2l

1 - 自注意力

num_hiddens,num_heads = 100,5
attention = d2l.MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,num_heads,0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size,num_queries,valid_lens = 2,4,torch.tensor([3,2])
X = torch.ones(batch_size,num_queries,num_hiddens)
attention(X,X,X,valid_lens).shape
torch.Size([2, 4, 100])

2 - 比较卷积神经网络、循环神经网络和自注意力

让我们⽐较下⾯⼏个架构,⽬标都是将由n个词元组成的序列映射到另⼀个⻓度相等的序列,其中的每个输⼊词元或输出词元都由d维向量表⽰。具体来说,我们将⽐较的是卷积神经⽹络、循环神经⽹络和⾃注意⼒这⼏个架构的计算复杂性、顺序操作和最⼤路径⻓度。请注意,顺序操作会妨碍并⾏计算,⽽任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系

3 - 位置编码

class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self,num_hiddens,dropout,max_len=1000):
        super(PositionalEncoding,self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1,max_len,num_hiddens))
        X = torch.arange(max_len,dtype=torch.float32).reshape(-1,1)/torch.pow(10000,torch.arange(0,num_hiddens,2,dtype=torch.float32) / num_hiddens)
        self.P[:,:,0::2] = torch.sin(X)
        self.P[:,:,1::2] = torch.cos(X)
        
    def forward(self,X):
        X = X + self.P[:,:X.shape[1],:].to(X.device)
        return self.dropout(X)

在位置嵌入矩阵P中,行代表词元在序列中的位置,列代表位置编码的不同维度。在下面的例子中,我们可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替

encoding_dim,num_steps = 32,60
pos_encoding = PositionalEncoding(encoding_dim,0)
pos_encoding.eval()

X = pos_encoding(torch.zeros((1,num_steps,encoding_dim)))
P = pos_encoding.P[:,:X.shape[1],:]
d2l.plot(torch.arange(num_steps),P[0,:,6:10].T,xlabel='Row (position)',figsize=(6,2.5),legend=["Col %d" % d for d in torch.arange(6,10)])


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0GTxej1f-1663075900559)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209132124216.svg)]

绝对位置信息

为了明白沿着编码维度单调降低的频率于绝对位置信息的关系,让我们打印出0,1…,7的二进制表示形式。正如我们所看到的,每个数字、每两个数字和每四个数字的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替

for i in range(8):
    print(f'{i}的二进制是:{i:>03b}')
0的二进制是:000
1的二进制是:001
2的二进制是:010
3的二进制是:011
4的二进制是:100
5的二进制是:101
6的二进制是:110
7的二进制是:111

在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间

P = P[0,:,:].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9Ihh6y9U-1663075900559)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209132124217.svg)]

相对位置信息

4 - 小结

  • 在自注意力中,查询、键和值都来自同一组输入
  • 卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢
  • 为了适应序列的顺序信息,我们可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值