Multi-Head Attention详解

在这里插入图片描述
文中大部分内容以及图片来自:https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a

当使用 multi-head attention 时,通常d_key = d_value =(d_model / n_heads),其中n_heads是头的数量。研究人员称,通常使用平行注意层代替全尺寸性,因为该模型能够“关注来自不同位置的不同表示子空间的信息”。

通过线性层传递输入

计算注意力的第一步是获得Q、K和V张量;它们分别是查询张量、键张量和值张量。它们是通过采用位置编码的嵌入来计算的,它将被记为X,同时将张量传递给三个线性层,它们被记为Wq, Wk和Wv。这可以从上面的详细图像中看到。

  • Q = XWq
  • K = XWk
  • V = XWv
    为了理解乘法是如何发生的,最好将每个组件分解成这个形状:
  • X的大小为(batch_size, seq_length, d_model)。例如,一批32个序列的长度为10,嵌入为512,其形状为(32,10,512)。
  • Wq,Wk和Wv的大小为(d_model,d_model)。按照上面的示例,它们的形状为(512,512)。

因此,可以更好地理解乘法的输出。每个重量矩阵同时在批处理中 broadcast 每个序列,以创建Q,K和V张量。

  • Q = XWq | (batch_size, seq_length, d_model) x (d_model, d_model) = (batch_size, seq_length, d_model)
  • K = XWk | (batch_size, seq_length, d_model) x (d_model, d_model) = (batch_size, seq_length, d_model)
  • V = XWv | (batch_size, seq_length, d_model) x (d_model, d_model) = (batch_size, seq_length, d_model)

下面的图片显示了Q, K和V是如何出现的。每个紫色盒子代表一个序列,每个橙色盒子是序列中的一个 token 或单词。灰色椭圆表示每个token 的嵌入。
在这里插入图片描述

下面的代码加载了Positional Encoding和Embeddings类。

# convert the sequences to integers
sequences = ["I wonder what will come next!",
             "This is a basic example paragraph.",
             "Hello, what is a basic split?"]

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]

# index the sequences 
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]

# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()

# vocab size
vocab_size = len(stoi)

# embedding dimensions
d_model = 8

# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)

# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)

# embed the sequence
embeddings = lut(tensor_sequences)

# positionally encode the sequences
X = pe(embeddings)
tensor([[[-3.45, -1.34,  4.12, -3.33, -0.81, -1.93, -0.28,  8.25],
         [ 7.36, -1.09,  2.32,  1.52,  3.50,  1.42,  0.46, -0.95],
         [-2.26,  0.53, -1.02,  1.49, -3.97, -2.19,  2.86, -0.59],
         [-3.87, -2.02,  1.46,  6.78,  0.88,  1.08, -2.97,  1.45],
         [ 1.12, -2.09,  1.19,  3.87, -0.00,  3.73, -0.88,  1.12],
         [-0.35, -0.02,  3.98, -0.20,  7.05,  1.55,  0.00, -0.83]],

        [[-4.27,  0.17, -2.08,  0.94, -6.35,  1.99,  5.23,  5.18],
         [-0.00, -5.05, -7.19,  3.27,  1.49, -7.11, -0.59,  0.52],
         [ 0.54, -2.33, -1.10, -2.02, -0.88, -3.15,  0.38,  5.26],
         [ 0.87, -2.98,  2.67,  3.32,  1.16,  0.00,  1.74,  5.28],
         [-5.58, -2.09,  0.96, -2.05, -4.23,  2.11, -0.00,  0.61],
         [ 6.39,  2.15, -2.78,  2.45,  0.30,  1.58,  2.12,  3.20]],

        [[ 4.51, -1.22,  2.04,  3.48,  1.63,  3.42,  1.21,  2.33],
         [-2.34,  0.00, -1.13,  1.51, -3.99, -2.19,  2.86, -0.59],
         [-4.65, -6.12, -7.08,  3.26,  1.50, -7.11, -0.59,  0.52],
         [-0.32, -2.97, -0.99, -2.05, -0.87, -0.00,  0.39,  5.26],
         [-0.12, -2.61,  2.77,  3.28,  1.17,  0.00,  1.74,  5.28],
         [-5.64,  0.49,  2.32, -0.00, -0.44,  4.06,  3.33,  3.11]]],
       grad_fn=<MulBackward0>)

此时,嵌入序列X的形状为(3,6,8)。有3个序列,包含6个标记,具有8维嵌入。

Wq、Wk和Wv的线性层可以使用nn.Linear(d_model, d_model)来创建。这将创建一个(8,8)矩阵,该矩阵将在跨每个序列的乘法期间广播。

Wq = nn.Linear(d_model, d_model)          # query weights (8,8)
Wk = nn.Linear(d_model, d_model)          # key weights   (8,8)
Wv = nn.Linear(d_model, d_model)          # value weights (8,8)

Wq.state_dict()['weight']
tensor([[ 0.19,  0.34, -0.12, -0.22,  0.26, -0.06,  0.12, -0.28],
        [ 0.09,  0.22,  0.32,  0.11,  0.21,  0.03, -0.35,  0.31],
        [-0.34, -0.21, -0.11,  0.34, -0.28,  0.03,  0.26, -0.22],
        [-0.35,  0.11,  0.17,  0.21, -0.19, -0.29,  0.22,  0.20],
        [ 0.19,  0.04, -0.07, -0.02,  0.01, -0.20,  0.30, -0.19],
        [ 0.23,  0.15,  0.22,  0.26,  0.17,  0.16,  0.23,  0.18],
        [ 0.01,  0.06, -0.31,  0.19,  0.22,  0.08,  0.15, -0.04],
        [-0.11,  0.24, -0.20,  0.26, -0.01, -0.14,  0.29, -0.32]])

Wq的权重如上图所示。Wk和Wv形状相同,但权重不同。当X穿过每一个线性层时,它保持它的形状,但是现在Q, K和V已经被权值转换成唯一的张量。

Q = Wq(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
K = Wk(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
V = Wv(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)

Q
tensor([
        # sequence 0
        [[-3.13,  2.71, -2.07,  3.54, -2.25, -0.26, -2.80, -4.31],
         [ 1.70,  1.63, -2.90, -2.90,  1.15,  3.01,  0.49, -1.14],
         [-0.69, -2.38,  3.00,  3.09,  0.97, -0.98, -0.10,  2.16],
         [-3.52,  2.08,  2.36,  2.16, -2.48,  0.58,  0.33, -0.26],
         [-1.99,  1.18,  0.64, -0.45, -1.32,  1.61,  0.28, -1.18],
         [ 1.66,  2.46, -2.39, -0.97, -0.47,  1.83,  0.36, -1.06]],
        # sequence 1
        [[-3.13, -2.43,  3.85,  4.34, -0.60, -0.03,  0.04,  0.62],
         [-0.82, -2.67,  1.82,  0.89,  1.30, -2.65,  2.01,  1.56],
         [-1.42,  0.11, -1.40,  1.36, -0.21, -0.87, -0.88, -2.24],
         [-2.70,  1.88, -0.10,  1.95, -0.75,  2.54, -0.14, -1.91],
         [-2.67, -1.58,  2.46,  1.93, -1.78, -2.44, -1.76, -1.23],
         [ 1.23,  0.78, -1.93, -1.12,  1.07,  2.98,  1.82,  0.18]],
        # sequence 2
        [[-0.71,  1.90, -1.12, -0.97, -0.23,  3.54,  0.65, -1.39],
         [-0.87, -2.54,  3.16,  3.04,  0.94, -1.10, -0.10,  2.07],
         [-2.06, -3.30,  3.63,  2.39,  0.38, -3.87,  1.86,  1.79],
         [-2.00,  0.02, -0.90,  0.68, -1.03, -0.63, -0.70, -2.77],
         [-2.76,  1.90,  0.14,  2.34, -0.93,  2.38, -0.17, -1.75],
         [-1.82,  0.15,  1.79,  2.87, -1.65,  0.97, -0.21, -0.54]]],
       grad_fn=<ViewBackward0>)

Q K和V都是这个形状。和前面一样,每个矩阵是一个序列,每一行都是由嵌入表示的 token。

把Q, K, V分成多头

通过创建Q,K和V张量,现在可以通过将D_Model的视图更改为 (n_heads, d_key) 将它们分为各自的头部。N_heads可以是任意数字,但是使用较大的嵌入时,通常要执行8、10或12。请记住,d_key = (d_model / n_heads)。
在这里插入图片描述
在前面的图像中,每个 token 在单个维度中包含d_model嵌入。现在,这个维度被分成行和列来创建一个矩阵;每行是一个包含键的头。这可以从上面的图片中看到。

每个张量的形状变成:

  • (batch_size, seq_length, d_model) → (batch_size, seq_length, n_heads, d_key)

假设示例中选择了四个heads,则(3,6,8)张量将被分成(3,6,4,2)张量,其中有3个序列,每个序列中有6个tokens,每个标记中有4个heads,每个正面中有2个元素。

这可以通过view来实现,它可以用来添加和设置每个维度的大小。由于每个示例的批大小或序列数量是相同的,因此可以设置批大小。同样,在每个张量中,头的数量和键的数量应该是恒定的。-1可用于表示剩余值,即序列长度。

batch_size = Q.size(0)   
n_heads = 4
d_key = d_model//n_heads # 8/4 = 2

# query tensor | -1 = query_length | (3, 6, 8) -> (3, 6, 4, 2)
Q = Q.view(batch_size, -1, n_heads, d_key)

# value tensor | -1 = key_length | (3, 6, 8) -> (3, 6, 4, 2) 
K = K.view(batch_size, -1, n_heads
Multi-head attention是一种注意力机制,它在Transformer模型中被引入。它可以看作是多个self-attention的组合,类似于CNN中的多核。不同于循环计算每个头,multi-head attention使用矩阵乘法来实现。它的计算流程可以通过转置和重塑来完成。使用多头注意力机制可以使模型同时关注来自不同表示子空间和不同位置的信息,从而提高模型的表达能力。理解self-attention的本质实际上就是了解multi-head attention结构。\[1\]\[2\]\[3\] #### 引用[.reference_title] - *1* [自注意力(Self-Attention)与Multi-Head Attention机制详解](https://blog.csdn.net/weixin_60737527/article/details/127141542)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [Multi-Head Attention的讲解](https://blog.csdn.net/qq_41980734/article/details/120842437)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [详解Transformer中Self-Attention以及Multi-Head Attention](https://blog.csdn.net/qq_37541097/article/details/117691873)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值