Contents
Scaled Dot-Product Attention
Additive Attention
Multi-Head Attention
PyTorch Implementation
Scaled Dot-Product Attention
For queries $Q \in \mathbb{R}^{n \times d}$, keys $K \in \mathbb{R}^{m \times d}$ and values $V \in \mathbb{R}^{m \times v}$, scaled dot-product attention computes $\mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$. This formula does not show batching explicitly; in practice each matrix carries an extra batch_size dimension, e.g. Q has shape (batch_size, n, d).
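As a minimal sketch (the shapes here are arbitrary examples, not taken from the implementation below), the batched computation maps directly onto torch.bmm:

import math
import torch

Q = torch.randn(2, 3, 4)  # (batch_size, n, d)
K = torch.randn(2, 5, 4)  # (batch_size, m, d)
V = torch.randn(2, 5, 6)  # (batch_size, m, v)
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(Q.shape[-1])  # (batch_size, n, m)
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
output = torch.bmm(weights, V)  # (batch_size, n, v)
print(output.shape)  # torch.Size([2, 3, 6])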
Multi-Head Attention
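In the standard formulation (which the implementation below follows by projecting to num_hiddens and splitting the result into h heads), each head applies scaled dot-product attention to linearly projected queries, keys and values, and the head outputs are concatenated and projected once more:

$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\ K W_i^K,\ V W_i^V), \quad i = 1, \dots, h$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O$$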
PyTorch Implementation
from torch import nn
import torch
import math
import matplotlib.pyplot as plt
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(6, 6),
                  cmap='Reds'):
    """
    Display matrices as heatmaps.
    :param matrices: 4-D tensor of shape (number of rows to display,
        number of columns to display, number of queries, number of keys),
        i.e. rows * columns matrices in total.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    :param titles: title of each subplot, given as a list.
    :param figsize: figure size.
    :param cmap: colormap of the heatmaps.
    :return:
    """
    # Number of subplot rows and columns: the figure holds num_rows * num_cols subplots.
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False)
    # Draw the subplot in row i, column j.
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):  # iterate over the first dimension of matrices: the rows.
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):  # iterate over the second dimension of matrices: the columns.
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)  # draw the heatmap.
            ax.set_xlabel(xlabel)  # x-axis label for each subplot.
            ax.set_ylabel(ylabel)  # y-axis label for each subplot.
            if titles:
                ax.set_title(titles[j])  # title for each subplot.
            fig.colorbar(pcm, ax=ax, shrink=0.6)  # add a colorbar for each subplot.
    plt.show()
def masked_softmax(X, valid_lens):
    """Perform a softmax operation by masking elements along the last axis."""
    # X: 3-D tensor, valid_lens: 1-D or 2-D tensor.
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements on the last axis are replaced by a very large negative
        # value, so that their softmax output is 0.
        X = X.reshape(-1, shape[-1])
        max_len = X.size(1)
        mask = torch.arange(max_len, dtype=torch.float32,
                            device=X.device)[None, :] < valid_lens[:, None]
        X[~mask] = -1e6
        return nn.functional.softmax(X.reshape(shape), dim=-1)
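# For example, masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])) keeps the
# first 2 columns of every row in batch element 0 and the first 3 columns of every
# row in batch element 1; the masked positions get (effectively) zero weight, while
# each row of the result still sums to 1.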
class AdditiveAttention(nn.Module):
    """Additive attention: scoring function a(q, k) = w_v^T tanh(W_q q + W_k k)."""
    def __init__(self, query_size, key_size, num_hiddens, dropout, **kwargs):
        """
        Define the network structure.
        :param query_size: dimension of the queries.
        :param key_size: dimension of the keys.
        :param num_hiddens: number of hidden units.
        :param dropout: dropout probability.
        :param kwargs:
        """
        super(AdditiveAttention, self).__init__(**kwargs)
        # The previous layer has query_size units and this layer has num_hiddens units,
        # so the weight matrix has shape W_q.shape = (num_hiddens, query_size).
        # For an input A the layer computes A @ W_q.T = B, i.e.
        # (batch_size, query_size) @ (num_hiddens, query_size).T = (batch_size, num_hiddens).
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)  # note that W_k is a network layer here, not a weight matrix.
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self, queries, keys, values, valid_lens):
        """
        Forward computation.
        :param queries: queries.
        :param keys: keys.
        :param values: values.
        :param valid_lens: valid lengths.
        :return:
        """
        # W_q.shape = (h, q).
        # W_k.shape = (h, k).
        # queries.shape = (batch_size, T1, q).
        # keys.shape = (batch_size, T2, k).
        # The matrix products are queries @ W_q.T and keys @ W_k.T:
        # queries: (batch_size, T1, q) @ (h, q).T = (batch_size, T1, h),
        # keys: (batch_size, T2, k) @ (h, k).T = (batch_size, T2, h).
        queries, keys = self.W_q(queries), self.W_k(keys)
        # unsqueeze(n) inserts a new dimension at dim=n (n starts from 0).
        # queries.shape = (batch_size, T1, h); after unsqueeze(2): (batch_size, T1, 1, h).
        # keys.shape = (batch_size, T2, h); after unsqueeze(1): (batch_size, 1, T2, h).
        queries = queries.unsqueeze(2)
        keys = keys.unsqueeze(1)
        features = queries + keys  # broadcast sum, features.shape = (batch_size, T1, T2, h).
        features = torch.tanh(features)
        # w_v has shape (1, h); the matrix product is features @ w_v.T:
        # (batch_size, T1, T2, h) @ (1, h).T = (batch_size, T1, T2, 1).
        # Squeezing the last dimension gives scores.shape = (batch_size, T1, T2).
        scores = self.w_v(features).squeeze(-1)
        # softmax does not change the shape: attention_weights.shape = (batch_size, T1, T2).
        # Without masking, the softmax could be computed directly as:
        # self.attention_weights = torch.softmax(scores, dim=-1)
        # which would give, for example:
        # tensor([[[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        #           0.1000, 0.1000]],
        #
        #         [[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        #           0.1000, 0.1000]]], grad_fn=<SoftmaxBackward0>)
        # Masking is applied to scores.shape = (batch_size, T1, T2):
        # if valid_lens is a 1-D tensor of shape (batch_size,), the same valid length
        # over T2 is used for every query (every position along T1) of that batch element;
        # if valid_lens is a 2-D tensor of shape (batch_size, T1), each query gets its
        # own valid length over T2.
        # For example valid_lens = [2, 6] is 1-D: the first batch element keeps 2 valid
        # keys and the second keeps 6.
        # Before masking:
        # tensor([[[-0.5469, -0.5469, -0.5469, -0.5469, -0.5469, -0.5469, -0.5469,
        #           -0.5469, -0.5469, -0.5469]],
        #
        #         [[ 0.0430,  0.0430,  0.0430,  0.0430,  0.0430,  0.0430,  0.0430,
        #            0.0430,  0.0430,  0.0430]]], grad_fn=<SqueezeBackward1>)
        # After masking:
        # tensor([[[0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        #           0.0000, 0.0000]],
        #
        #         [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
        #           0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # bmm is batched matrix multiplication; both tensors must have the same batch_size.
        # attention_weights.shape = (batch_size, T1, T2),
        # values.shape = (batch_size, T2, v).
        # Keeping batch_size fixed, the remaining dimensions multiply as matrices:
        # (T1, T2) @ (T2, v) = (T1, v),
        # so the result of bmm has shape (batch_size, T1, v).
        return torch.bmm(self.dropout(self.attention_weights), values)
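# For example, with queries of shape (2, 1, 20), keys of shape (2, 10, 2) and values
# of shape (2, 10, 4) (as in test_additive_attention below), attention_weights has
# shape (2, 1, 10) and the returned tensor has shape (2, 1, 4).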
class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, number of queries, d)
    # Shape of keys: (batch_size, number of key-value pairs, d)
    # Shape of values: (batch_size, number of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, number of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        """
        :param queries: (batch_size, n, d)
        :param keys: (batch_size, m, d)
        :param values: (batch_size, m, v)
        :param valid_lens:
        :return:
        """
        d = queries.shape[-1]
        # (batch_size, n, m)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # (batch_size, n, m)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # (batch_size, n, m) @ (batch_size, m, v) = (batch_size, n, v)
        return torch.bmm(self.dropout(self.attention_weights), values)
class MultiHeadAttention(nn.Module):
    """Multi-head attention"""
    def __init__(self, query_size, key_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads  # number of attention heads.
        self.attention = DotProductAttention(dropout)  # scaled dot-product attention.
        # Input shape: (batch_size, query_size); output shape: (batch_size, num_hiddens);
        # weight matrix shape: (num_hiddens, query_size), applied as A @ W.T = B.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)  # linear projection of the queries.
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)  # linear projection of the keys.
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)  # linear projection of the values.
        self.W_o = nn.Linear(num_hiddens, value_size, bias=bias)  # linear projection of the concatenated heads.

    def forward(self, queries, keys, values, valid_lens):
        # Reshape queries, keys and values so that all heads can be computed in parallel.
        # queries.shape = (batch_size, T1, q).
        # W_q.shape = (num_hiddens, q).
        # self.W_q(queries) has shape (batch_size, T1, num_hiddens).
        # After the transform, queries.shape = (batch_size * h, T1, num_hiddens / h).
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        # keys.shape = (batch_size, T2, k).
        # W_k.shape = (num_hiddens, k).
        # self.W_k(keys) has shape (batch_size, T2, num_hiddens).
        # After the transform, keys.shape = (batch_size * h, T2, num_hiddens / h).
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        # values.shape = (batch_size, T2, v).
        # W_v.shape = (num_hiddens, v).
        # self.W_v(values) has shape (batch_size, T2, num_hiddens).
        # After the transform, values.shape = (batch_size * h, T2, num_hiddens / h).
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Along axis 0, copy the first item (scalar or vector) num_heads times,
            # then copy the second item, and so on.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        # Scaled dot-product attention.
        # output.shape = (batch_size * h, T1, num_hiddens / h).
        output = self.attention(queries, keys, values, valid_lens)
        # After the parallel per-head computation, transform the shape back:
        # output_concat.shape = (batch_size, T1, num_hiddens).
        output_concat = transpose_output(output, self.num_heads)
        # output_concat.shape = (batch_size, T1, num_hiddens).
        # W_o.shape = (value_size, num_hiddens).
        # The matrix product output_concat @ W_o.T
        # gives the final output of shape (batch_size, T1, value_size).
        return self.W_o(output_concat)
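# For example, with queries of shape (2, 4, 20), keys of shape (2, 6, 20), values of
# shape (2, 6, 30), num_hiddens=100 and num_heads=5 (as in test_multi_head_attention
# below), the per-head tensors have shapes (10, 4, 20) and (10, 6, 20), and the final
# output has shape (2, 4, 30).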
# @save
def transpose_qkv(X, num_heads):
    """
    Transform the shape for parallel computation of multiple attention heads.
    :param X: (batch_size, T, d)
    :param num_heads: h
    :return: X with X.shape = (batch_size * h, T, d / h)
    """
    # If X has shape (batch_size, T, d),
    # it becomes (batch_size, T, h, d / h) after the reshape.
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # Reorder the dimensions of X according to the indices passed to permute:
    # (batch_size, T, h, d / h) becomes (batch_size, h, T, d / h).
    X = X.permute(0, 2, 1, 3)
    # The final output has shape (batch_size * h, T, d / h).
    return X.reshape(-1, X.shape[2], X.shape[3])
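# For example, with num_hiddens=100 and num_heads=5, an input of shape (2, 4, 100)
# is reshaped to (2, 4, 5, 20), permuted to (2, 5, 4, 20) and finally returned as
# (10, 4, 20).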
def transpose_output(X, num_heads):
    """
    Reverse the operation of transpose_qkv.
    :param X: X.shape = (batch_size * h, T, d / h)
    :param num_heads: h
    :return: X with X.shape = (batch_size, T, d)
    """
    # If X.shape = (batch_size * h, T, d / h),
    # it becomes (batch_size, h, T, d / h) after the reshape.
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # Swap dimensions: X.shape becomes (batch_size, T, h, d / h).
    X = X.permute(0, 2, 1, 3)
    # Finally X.shape = (batch_size, T, d).
    return X.reshape(X.shape[0], X.shape[1], -1)
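# transpose_output is the exact inverse of transpose_qkv:
# transpose_output(transpose_qkv(X, num_heads), num_heads) recovers X for any X of
# shape (batch_size, T, d) with d divisible by num_heads.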
def test_additive_attention():
    # queries.shape = (batch_size, T1, q), with T1 = 1.
    # Meaning: the query sequence has length T1, i.e. T1 tokens, each with feature
    # dimension q, and batch_size sequences are processed at once.
    # keys.shape = (batch_size, T2, k), with T2 = 10.
    # Meaning: the key sequence has length T2, i.e. T2 tokens, each with feature
    # dimension k, and batch_size sequences are processed at once.
    queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
    print('queries=', queries)
    print('keys=', keys)
    # values.shape = (batch_size, T2, v), since keys and values come in pairs.
    values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
    print('values=', values)
    # Initialize the additive attention.
    attention = AdditiveAttention(query_size=queries.shape[2],
                                  key_size=keys.shape[2],
                                  num_hiddens=8,
                                  dropout=0.1)
    # Print the structure of the attention module:
    val = attention.eval()
    print('val=', val)
    valid_lens = torch.tensor([2, 6])
    # result.shape = (batch_size, T1, v).
    attention_val = attention(queries, keys, values, valid_lens)
    print('attention_val=', attention_val)
    show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries', titles=['AdditiveAttention'])
def test_dot_product_attention():
    queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
    print('queries=', queries)
    print('keys=', keys)
    # values.shape = (batch_size, T2, v), since keys and values come in pairs.
    values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
    print('values=', values)
    attention = DotProductAttention(dropout=0.5)
    val = attention.eval()
    print('val=', val)
    valid_lens = torch.tensor([2, 6])
    attention_val = attention(queries, keys, values, valid_lens)
    print('attention_val=', attention_val)
    show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries', titles=['DotProductAttention'])
def test_multi_head_attention():
    # queries.shape = (batch_size, T1, q).
    batch_size, num_queries = 2, 4
    query_size = 20  # q = 20.
    queries = torch.ones((batch_size, num_queries, query_size))
    # keys.shape = (batch_size, T2, k).
    num_kv_pairs = 6
    key_size = 20  # k = 20.
    keys = torch.ones((batch_size, num_kv_pairs, key_size))
    # values.shape = (batch_size, T2, v).
    value_size = 30  # v = 30.
    values = torch.ones((batch_size, num_kv_pairs, value_size))
    valid_lens = torch.tensor([3, 2])
    num_hiddens, num_heads = 100, 5
    attention = MultiHeadAttention(query_size=query_size,
                                   key_size=key_size,
                                   value_size=value_size,
                                   num_hiddens=num_hiddens,
                                   num_heads=num_heads,
                                   dropout=0.5)
    val = attention.eval()
    print('val=', val)
    # The returned tensor has shape (batch_size, T1, value_size).
    attention_val = attention(queries, keys, values, valid_lens)
    # (batch_size, T1, value_size) = (2, 4, 30)
    print('attention_val.shape=', attention_val.shape)
if __name__ == '__main__':
    # Additive attention test.
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    test_additive_attention()
    # Scaled dot-product attention test.
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    test_dot_product_attention()
    # Multi-head attention test.
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    test_multi_head_attention()
Program output:
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
queries= tensor([[[-0.8907, -1.3623, 0.1130, -1.2687, -1.4991, -0.3159, 0.3349,
-0.7298, -1.8507, -2.5918, -0.2319, -1.3374, 0.5941, 0.9423,
-0.5563, 0.4634, -0.1660, -1.4720, -1.6535, 0.2234]],
[[ 0.5545, -1.1212, -1.6501, -0.5919, 1.4009, 0.9828, 0.4245,
-1.4334, 0.1970, -0.4550, -2.1928, 0.1660, -1.1109, -1.5174,
-1.3909, 1.8204, -0.8928, -0.6160, -0.2877, -0.0853]]])
keys= tensor([[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]])
values= tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.],
[32., 33., 34., 35.],
[36., 37., 38., 39.]],
[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.],
[32., 33., 34., 35.],
[36., 37., 38., 39.]]])
val= AdditiveAttention(
(W_q): Linear(in_features=20, out_features=8, bias=False)
(W_k): Linear(in_features=2, out_features=8, bias=False)
(w_v): Linear(in_features=8, out_features=1, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
)
attention_val= tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
queries= tensor([[[ 0.1392, -0.7757]],
[[-1.8050, -0.4379]]])
keys= tensor([[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]])
values= tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.],
[32., 33., 34., 35.],
[36., 37., 38., 39.]],
[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.],
[32., 33., 34., 35.],
[36., 37., 38., 39.]]])
val= DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
attention_val= tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]])
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
val= MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=20, out_features=100, bias=False)
(W_k): Linear(in_features=20, out_features=100, bias=False)
(W_v): Linear(in_features=30, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=30, bias=False)
)
attention_val.shape= torch.Size([2, 4, 30])
Visualization: attention-weight heatmaps for AdditiveAttention and DotProductAttention (Keys on the x-axis, Queries on the y-axis).