零基础-动手学深度学习-10.5. 多头注意力

10.5.1. 模型

import math
import torch
from torch import nn
from d2l import torch as d2l

10.5.2. 实现

#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)#你有多少个维度,就把你的特征拿出来切开成多少个heads,这里的-1是自动计算了

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)#维度顺序调转之后把两个维度合在一起

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

transpose_qkv 把“多头”拆出来并拉成二维,便于并行计算
transpose_output 再把计算结果还原成原来的三维形状。

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

输出:MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

10.5.3. 小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

  • 基于适当的张量操作,可以实现多头注意力的并行计算。

NumPy中,ndarray对象内容可通过索引切片来访问修改,索引方法类型包括字段访问、基本切片高级索引,存取元素的方法Python序列型数据的索引切片方法相同[^1]。 ### 一维数组索引切片 以下是一些一维数组索引切片的示例代码: ```python import numpy as np # 创建一维数组 a = np.arange(0, 20) # 访问单个元素 print(a[4]) # 切片操作 print(a[4:10]) ``` 在上述代码中,`a[4]` 用于访问数组 `a` 中索引为 4 的元素,`a[4:10]` 用于获取数组 `a` 中索引从 4 到 9 的元素组成的子数组[^3]。 ### 二维数组的索引切片 以下是二维数组索引切片的示例代码: ```python import numpy as np # 创建二维数组 b = np.arange(0, 16).reshape(4, 4) # 访问某一行 print(b[2]) # 二次索引访问单个元素 print(b[2][1]) # 另一种方式访问单个元素 print(b[2, 1]) # 切片得到二维子数组 print(b[1:3]) # 更复杂的切片 print(b[0:2, 1:]) ``` 在上述代码中,`b[2]` 用于获取二维数组 `b` 的第 3 行(索引从 0 开始),`b[2][1]` `b[2, 1]` 都用于获取第 3 行第 2 列的元素,`b[1:3]` 用于获取第 2 行第 3 行组成的二维子数组,`b[0:2, 1:]` 用于获取第 1 行第 2 行中从第 2 列开始到最后一列的元素组成的二维子数组[^3]。 ### 三维数组的索引切片 以下是三维数组索引切片的示例代码: ```python import numpy as np # 创建三维数组 c = np.arange(0, 12).reshape(3, 2, 2) # 访问下一个维度的元素 print(c[2]) # 进一步访问元素 print(c[2][0]) # 访问单个值 print(c[2][0][1]) ``` 在上述代码中,`c[2]` 用于获取三维数组 `c` 下一个维度的第 3 个元素(一个二维数组),`c[2][0]` 用于获取该二维数组的第 1 行(一个一维数组),`c[2][0][1]` 用于获取该一维数组索引为 1 的元素[^3]。 ### 多维数组的索引提取 多维数组同样适用上述索引提取方法,例如: ```python import numpy as np a = np.array([[1, 2, 3], [3, 4, 5], [4, 5, 6]]) # 从某个索引处开始切割 print('从数组索引 a[1:] 处开始切割') print(a[1:]) ``` 在上述代码中,`a[1:]` 用于获取从第 2 行开始到最后一行的元素组成的二维子数组[^4]。 ### 切片索引的视图特性 在NumPy中,切片索引操作得到的是原数组的一个视图,进行任何切片或者索引操作并不会产生一个新的数组,都是在原数组的基础上操作。示例如下: ```python import numpy as np arr = np.array([1, 2, 3, 4, 5]) view = arr[1:4] view[0] = 20 print(arr) ``` 在上述代码中,`view` 是 `arr` 的一个视图,修改 `view` 中的元素会影响到原数组 `arr`,最终输出 `[ 1 20 3 4 5]`[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值