目录
由于李沐在注意力部分的讲解有些难以理解,本人强烈推荐:
- 李宏毅老师的3h速通讲解自注意力机制+transformer,讲解通俗易懂,令本人拍案叫绝(强烈推荐!台大李宏毅自注意力机制和Transformer详解!_哔哩哔哩_bilibili);
- 本文会在李沐的讲解的基础上,补充李宏毅的讲解,大大降低学习曲线的陡峭程度
- 接下来的文章只讲重点,以最快的速度直奔Transformer
- 课程全部代码(pytorch版)已上传到附件
- 本章节为原书第10章(注意力机制),共分为7节,本篇是第1节+6节:
- 10.1 注意力提示
- 10.6 自注意力 & 位置编码
- 本文的代码位置为:
- chapter_attention-mechanisms/attention-cues.ipynb
- chapter_attention-mechanisms/self-attention-and-positional-encoding.ipynb
- 本文的视频链接:
使用Transformer,可以捕获和学习整个句子的上下文关系
1. 注意力提示
1.1 两种注意力:
- 非自主性注意力:简单 → 偏感官本能 → 可简单使用(可学习)参数化的全连接(mlp),甚至非参的max pooling (最大池化) / avg pooling (平均池化) 来实现
- 自主性注意力:查询 (自主性提示) → 在自注意力里,以下图为例,对于每个传入的特征,会学习与书本比较近的所有输入特征,这些特征在向量空间里离得比较近
1.2 非参注意力池化层
- 给定数据(xi,yi),就是i个键值对:
- 之前的词表字典,就算是一种环境;
- 在上个例子中,杯子,书也算是环境;
- f(x)传入的x就是查询:
- 平均池化相当于把所有环境中的值平均,并返回
- 核回归是非参的(没有可学习参数),类似于 k-NN (k-近邻),这和现代的注意力池化层方法很相近
- 核回归的K,可选择高斯核(Gaussian kernel):
- exp把输出都变为大于零的数
- 把这个式子K(u)作为注意力权重α,代入平均池化的公式,等于做了softmax:
- 给定查询x(查询)时,先和每个 xi(键) 算一下距离,距离近的注意力权重大
- 如果一个键 xi(键) 越是接近给定的查询x(查询), 那么分配给这个键对应值 yi(值) 的注意力权重就会越大, 也就“获得了更多的注意力”
如何理解距离(以词嵌入为例)
- 如下图所示,每一个输入都先通过嵌入层(可学习的线性层),转为一个稠密的向量(有元素之间的距离),远好于稀疏的独热编码
- (图片来自李宏毅-自注意力机制)
- 进一步理解词嵌入和Embedding层
1.3 参数化的注意力池化层
- 在之前(非参)的基础上引入可以学习的 w
- 类似于做矩阵乘法,类似于全连接层,以分配(可学习的)不同权重
2. 自注意力
在自注意力部分,会使用李宏毅的讲解作为引言:
2.1 为什么需要自注意力
但是序列的长度是有长有短的,输入给模型的序列的长度, 每次可能都不一样。如果要开一个窗口把整个序列盖住,可能要统计一下训练数据,看看训 练数据里面最长序列的长度。接着开一个窗口比最长的序列还要长,才可能把整个序列盖住。 但是开一个这么大的窗口,意味着全连接网络需要非常多的参数,可能不只运算量很大,还容易过拟合。如果想要更好地考虑整个输入序列的信息,就要用到自注意力模型。
2.2 自注意力层
自注意力层可以考虑整句话的上下文,可以理解为带权重学习的多层感知机,参数比直接上全连接层要少很多(权重参数矩阵只有三个QKV),计算开销会很大(这些参数会大量重复用来计算)。而且如上图所示,自注意力层可以像线性层那样叠加。
接下来打开线性层,看看它的内部结构,看起来有点像线性层,但拆开来看实际上会复杂一些:
接下来会先介绍b1的产生过程,剩下的以此类推。自注意力的目的是考虑整个序列,但是又不希望把整个序列所有的信息包在一个窗口里面。自注意力想找到整个序列里那些元素与a1相关,学习它们的特征,得到加权值输出为b1。
2.3 注意力打分 (关联程度 & QKV)
接下来是 注意力分数 α 的计算方法,如图6.18所示:
在上图中,每个 W 都是一个可学习参数,是一个权重矩阵。目前有两种主流的注意力分数计算方法,在接下来的内容里面,我们都只用点积这个方法, 这也是目前最常用的方法,也是用在 Transformer 里面的方法。
最后通过softmax,将权重值拉到0和1之间,对于a1,它与每个输入的注意力分数(α'、α'
、α'
、α'
)的和为1。
2.3.1 掩码 softmax 操作
softmax操作用于输出一个概率分布作为注意力权重。 在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。 例如,为了在 机器翻译 中高效处理小批量数据集, 某些文本序列被填充了没有意义的特殊词元。 为了仅将有意义的词元作为值来获取注意力汇聚, 可以指定一个有效序列长度(即词元的个数), 以便在计算softmax时过滤掉超出指定范围的位置。 下面的masked_softmax
函数 实现了这样的掩码softmax操作(masked softmax operation), 其中任何超出有效长度的位置都被掩蔽并置为0。
import math
import torch
from torch import nn
from d2l import torch as d2l
def masked_softmax(X, valid_lens):
"""通过在最后一个轴上掩蔽元素来执行softmax操作"""
# X: 3D张量,valid_lens: 1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim=-1) # 直接对X在最后一个维度上做softmax
else:
shape = X.shape # 获取X的形状
if valid_lens.dim() == 1:
# 将1D的valid_lens重复interleave成与X的第二个维度相同大小
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
# 将2D的valid_lens拉平成一维
valid_lens = valid_lens.reshape(-1)
# 将X reshape成(-1, shape[-1]),并根据valid_lens掩蔽序列,超过的部分用-1e6填充
X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
# 将X reshape回原来的形状,并在最后一个维度上做softmax
return nn.functional.softmax(X.reshape(shape), dim=-1)
为了[演示此函数是如何工作]的, 考虑由两个2×4矩阵表示的样本, 这两个样本的有效长度分别为2和3。 经过掩码softmax操作,超出有效长度的值都被掩蔽为0。
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
Out:
tensor([[[0.3929, 0.6071, 0.0000, 0.0000], [0.6768, 0.3232, 0.0000, 0.0000]], [[0.2284, 0.4979, 0.2737, 0.0000], [0.2541, 0.4405, 0.3054, 0.0000]]])
同样,也可以使用二维张量,为矩阵样本中的每一行指定有效长度:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
Out:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000], [0.3210, 0.2656, 0.4134, 0.0000]], [[0.4946, 0.5054, 0.0000, 0.0000], [0.2724, 0.3046, 0.1435, 0.2796]]])
2.3.2 缩放点积注意力
使用点积可以得到计算效率更高的评分函数, 但是点积操作要求查询和键具有相同的长度𝑑。 假设查询和键的所有元素都是独立的随机变量, 并且都满足零均值和单位方差, 那么两个向量的点积的均值为0,方差为𝑑。 为确保无论向量长度如何, 点积的方差在不考虑向量长度的情况下仍然是1, 我们再将点积除以,则缩放点积注意力(scaled dot-product attention)评分函数为:
QKV的缩放点积注意力公式是:
下面的缩放点积注意力的实现使用了 暂退 (丢弃-dropout) 法 进行模型正则化。
class DotProductAttention(nn.Module):
"""缩放点积注意力"""
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout) # 定义 dropout 层
# queries的形状:(batch_size,查询的个数,d)
# keys的形状:(batch_size,“键-值”对的个数,d)
# values的形状:(batch_size,“键-值”对的个数,值的维度)
# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数),有效长度,在seq2seq里有讲
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1] # 获取查询的维度 d
# 计算 queries 和 keys 的点积(bmm),对应公式 A = K转置·Q;(transpose)交换 keys 的最后两个维度
scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d) # 缩放点积, 其中sqrt:开平方根
self.attention_weights = masked_softmax(scores, valid_lens) # 应用掩码softmax, 由布尔值看哪些要算
# 将 dropout 应用于注意力权重,然后与 values 做 batch 矩阵乘法(bmm)
return torch.bmm(self.dropout(self.attention_weights), values) # 对应公式 O = A·V
为了[演示上述的DotProductAttention
类], 我们自定义一些键、值和有效长度。 对于点积操作,我们令查询的特征维度与键的特征维度大小相同。
# 其中查询、键和值的形状为(批量大小,步数或词元序列长度,特征大小),
# 实际输出为 q:(2,1,20), k:(2,10,2) 和 v:(2,10,4)
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6]) # 两行序列样本的有效长度是2和6(从0开始数)
queries = torch.normal(0, 1, (2, 1, 2)) # 查询q和键k的形状要相同
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
# 经过点积注意力后,输出的形状为(批量大小,查询的步数,值的维度) = (2, 2, 4)
Out:
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]], [[10.0000, 11.0000, 12.0000, 13.0000]]])
由于键包含的是相同的元素(keys = torch.ones((2, 10, 2))生成全1的矩阵量), 而这些元素无法通过任何查询进行区分,因此获得了[均匀的注意力权重]。
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries') # 将权重形状做成(1, 1, 查询的步数, 序列长度)
Out:
- 输出解读:输出的带掩码的权重,两个输出的长度符合输入的两个有效长度[2, 6],长度里的每个时间步都分到了均匀的注意力权重。
2.4 从注意力权重到输出
(上边的代码DotProductAttention的输入是qkv,输出就是b1
)从a到v(特征到值)的转换,也需要乘上一个可学习权重参数W。谁的注意力的分数(α)最大,谁的 v 就会主导(dominant)输出的结果。这里讲述了如何从一整个序列得到 b1。同理,可以计算出 b2 到 b4。
2.5 用矩阵理解自注意力运算
在实际深度学习框架的运算中,不会去一个一个算,而是会拼接成矩阵,进行矩阵乘法,并行计算。在上图中先把整个序列拼接起来(每一列是一个样本的向量类似于词嵌入),分别乘以三个权重W,得到QKV(也是拼接起来的)。
以从 a(I) 到 q(Q) 为例,,每个a要得到q,乘的W是相同的。
- I 的形状:(每个a的嵌入维度,序列长度n)= (嵌入维度, n)
- W
的形状(三种W形状相同):(每个q的通道数d,每个a的嵌入维度)= (d, 嵌入维度)
- Q K V 的形状相同: (通道数d,序列长度n) = (d, n)
K和V也类似。接下来是计算注意力分数,先算q1的注意力分数:
每个q都有四个注意力分数,所有的q的注意力分数拼起来,可以拼接成矩阵,之后按列(即针对每个查询向量 q)进行 softmax 操作):
最后就是得到输出b(O)啦:
- I 的形状:(每个a的嵌入维度,序列长度n)= (嵌入维度, n)
- W
的形状(三种W形状相同):(每个q的通道数d,每个a的嵌入维度)= (d, 嵌入维度)
- Q K V 的形状相同: (通道数d,序列长度n) = (d, n)
- K
的形状: (序列长度n,通道数d) = (n, d)
- A 和 A' 的形状:(序列长度n,序列长度n)= (n, n)
- 输出 O 的形状:(序列长度n,通道数d) = (d, n)

2.6 (复杂度) 对比CNN和RNN
2.6.1 对比CNN
考虑一个卷积核大小为𝑘的卷积层。 在后面的章节将提供关于使用卷积神经网络处理序列的更多详细信息。 目前只需要知道的是,由于序列长度是𝑛,输入和输出的通道数量都是𝑑, 所以卷积层的计算复杂度为。 如上所示, 卷积神经网络是分层的,因此为有
个顺序操作, 最大路径长度为
。 例如,𝐱1和𝐱5处于图中卷积核大小为3的双层卷积神经网络的感受野内。
2.6.2 对比RNN
当更新循环神经网络的隐状态时, 𝑑×𝑑权重矩阵和𝑑维隐状态的乘法计算复杂度为。 由于序列长度为𝑛𝑛,因此循环神经网络层的计算复杂度为
。 根据上图, 有O(𝑛)个顺序操作无法并行化,最大路径长度也是O(𝑛)。
一个不同是RNN如果距离太远就难以考虑到,另一个不同是自注意力可以并行计算。
2.6.3 自注意力本身
- I 的形状:(每个a的嵌入维度,序列长度n)= (嵌入维度, n)
- W
的形状(三种W形状相同):(每个q的通道数d,每个a的嵌入维度)= (d, 嵌入维度)
- Q K V 的形状相同: (通道数d,序列长度n) = (d, n)
- K
的形状: (序列长度n,通道数d) = (n, d)
- A 和 A' 的形状:(序列长度n,序列长度n)= (n, n)
- 输出 O 的形状:(序列长度n,通道数d) = (d, n)
在自注意力中,查询Q、键K 和 值V 都是𝑛×𝑑矩阵。 考虑”点-积“注意力,其中𝑑×𝑛矩阵 (Q) 乘以 𝑛×𝑑矩阵 (K) 。之后输出的 𝑛×𝑛矩阵 (注意力矩阵A') 乘以 𝑑×𝑛矩阵 (V)。 因此,自注意力具有
计算复杂性。 正如在上图所示, 每个词元都通过自注意力直接连接到任何其他词元。 因此,有O(1)个顺序操作可以并行计算, 最大路径长度也是O(1)。
总而言之,卷积神经网络和自注意力都拥有并行计算的优势, 而且自注意力的最大路径长度最短。 但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。
3 多头注意力
3.1 为什么需要多头注意力
自注意力有一个进阶的版本——多头自注意力(multi-head self-attention)。多头自注意力的使用是非常广泛的,有一些任务,比如翻译、语音识别,用比较多的头可以得到比较好的结果。至于需要用多少的头,这个又是另外一个超参数,也是需要调的。为什么会需要更多的头呢?在使用自注意力计算相关性的时候,就是用 q 去找相关的 k。但是相关有很多种不同的形式,所以也许可以有多个 q,不同的 q 负责不同种类的相关性,这就是多头注意力。
李沐的讲解:
- 与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的ℎ组不同的 线性投影(linear projections)来变换查询、键和值。
- 然后,这ℎ组变换后的查询、键和值将并行地送到注意力汇聚中。
- 最后,将这ℎ个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。
- 这种设计被称为多头注意力(multihead attention) :cite:
Vaswani.Shazeer.Parmar.ea.2017
。- 对于ℎ个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。
如图 6.29 所示,先把 a 乘上一个矩阵得到 q,接下来再把 q 乘上另外两个矩阵(可学习的W),分别得到代表有两个头,i 代表的是位置,1 跟 2 代表是这个位置的第几个 q,这个问题里面有两种不同的相关性,所以需要产生两种不同的头来找两种不同的相关性。既然 q 有两个,k 也就要有两个,v 也就要有两个。
多头自注意力是单头的拓展版本,核心思想一致。
最后从矩阵乘法来的角度来更好地理解多头注意力:
3.2 代码实现
在实现过程中通常[选择缩放点积注意力作为每一个注意力头]。 为了避免计算代价和参数代价的大幅增长, 我们设定。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为
, 则可以并行计算 ℎ 个头。 在下面的实现中,p
是通过参数
num_hiddens
指定的。
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads # 设置头的数量
self.attention = d2l.DotProductAttention(dropout) # 定义点积注意力机制
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) # 查询线性变换
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) # 键线性变换
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) # 值线性变换
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) # 输出线性变换
def forward(self, queries, keys, values, valid_lens):
# queries,keys,values的形状:
# (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens 的形状:
# (batch_size,)或(batch_size,查询的个数)
# 经过变换后,输出的queries,keys,values 的形状:
# (batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 在轴0,将第一项(标量或者矢量)复制num_heads次,
# 然后如此复制第二项,然后诸如此类。
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# output的形状:(batch_size*num_heads,查询的个数,
# num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)
# output_concat的形状:(batch_size,查询的个数,num_hiddens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
为了能够[使多个头并行计算], 上面的MultiHeadAttention
类将使用下面定义的两个转置函数。 具体来说,transpose_output
函数反转了transpose_qkv
函数的操作。
def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状(做成矩阵乘法,避免for循环)"""
# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
# num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)
# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
下面使用键和值相同的小例子来[测试]我们编写的MultiHeadAttention
类。 多头注意力输出的形状是(batch_size
,num_queries
,num_hiddens
)。
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
Out:
MultiHeadAttention( (attention): DotProductAttention( (dropout): Dropout(p=0.5, inplace=False) ) (W_q): Linear(in_features=100, out_features=100, bias=False) (W_k): Linear(in_features=100, out_features=100, bias=False) (W_v): Linear(in_features=100, out_features=100, bias=False) (W_o): Linear(in_features=100, out_features=100, bias=False)
- 本人在做相关实验时,经常遇到加入多层自注意力后爆显存的情况,多头注意力非常占显存,而且运算时长会增加不少
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape # 这个例子里,键和值相同,都是Y
Out:
torch.Size([2, 4, 100])
4. 位置编码
4.1 为什么需要位置编码
- 在处理词元序列时,循环神经网络是逐个的重复地处理词元的, 而自注意力则因为并行计算而放弃了顺序操作。
- 对自注意力而言,q1 跟 q4 的距离并没有特别远,1 跟 4 的距离并没有特别远,2 跟 3 的距离也没有特别近。
4.2 位置编码的概念
4.3 位置编码的代码实现

class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__() # 调用父类构造函数
self.dropout = nn.Dropout(dropout) # 设置 dropout 层
# 创建一个形状为 (1, max_len, num_hiddens) 的全零张量,用于存储位置编码
self.P = torch.zeros((1, max_len, num_hiddens))
# 生成位置编码,用代码实现上方的公式
X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / \ # 反斜杠 \ 是Python的续行功能,
# torch.pow(10000, ...) 计算10000的幂,得到一个缩放因子 # 将长表达式分成多行书写,提高代码可读性
torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X) # 偶数维度用正弦函数填充
self.P[:, :, 1::2] = torch.cos(X) # 奇数维度用余弦函数填充
def forward(self, X):
# 将输入 X 和对应长度的位置编码相加,确保在相同设备上
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X) # 对加上位置编码的 X 应用 dropout,然后返回
在位置嵌入矩阵𝐏中, [行代表词元在序列中的位置,列代表位置编码的不同维度]。 从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。 第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。
encoding_dim, num_steps = 32, 60 # 定义嵌入维度和序列长度
pos_encoding = PositionalEncoding(encoding_dim, 0) # 创建位置编码对象,不使用 dropout
pos_encoding.eval() # 设置模型为评估模式
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim))) # 生成位置编码后的张量
P = pos_encoding.P[:, :X.shape[1], :] # 提取位置编码张量,取前 num_steps 个时间步
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)', # 绘制位置编码的第6到第9维
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)]) # 设置图例和图形大小
4.4 绝对位置
为了明白沿着编码维度单调降低的频率与绝对位置信息的关系, 让我们打印出0,1,…,6,7的[二进制表示]形式。 正如所看到的,每个数字、每两个数字和每四个数字上的比特值 在第一个最低位、第二个最低位和第三个最低位上分别交替。
for i in range(8): # 循环从0到7
print(f'{i}的二进制是:{i:>03b}') # 打印i及其二进制表示,二进制格式宽度为3,补0对齐
0的二进制是:000 1的二进制是:001 2的二进制是:010 3的二进制是:011 4的二进制是:100 5的二进制是:101 6的二进制是:110 7的二进制是:111
P = P[0, :, :].unsqueeze(0).unsqueeze(0) # 选择张量P的第一个元素,并在最前面添加两个新维度
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)', # 显示热图,设置x轴标签为“列(编码维度)”
# 设置y轴标签为“行(位置)”,图的尺寸为(3.5, 4),颜色映射为“Blues”
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
4.5 相对位置

5. 小节
- 现在,你已经集齐了学习 Transformer 的所有组件啦(seq2seq,多头自注意力,束搜索,位置编码...)
- 剩下的就是简单的组装环节,现在再来看这张 Transformer 架构图,就不难理解啦
这里的掩码注意力,和seq2seq的掩码类似,是为了防止模型看到未来的序列,因为我们需要模型用过去预测未来
资源推荐:
- pytorch transformer 官方源码:pytorch/torch/nn/modules/transformer.py at v2.6.0 · pytorch/pytorch (github.com)
- pytorch transformer 源码讲解:pytorch1.2 transformer 的调用方法_transformer调用-优快云博客