PyTorch nn.MultiHead() 参数理解

本文详细介绍了PyTorch中nn.MultiheadAttention模块的使用方法和注意事项,包括模型初始化、参数解释、输入输出格式,并通过实例展示了如何避免报错。讨论了embed_dim与num_heads的关系,强调了它们之间的约束条件。此外,还探讨了attn_mask参数的作用,用于在self-attention中屏蔽不可见的位置。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

之前一直是自己实现MultiHead Self-Attention程序,代码段又臭又长。后来发现Pytorch 早已经有API nn.MultiHead()函数,但是使用时我却遇到了很大的麻烦。

首先放上官网说明:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O w h e r e   h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) MultiHead(Q,K,V)=Concat(head_1,…,head_h)W_O\quad where\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) MultiHead(Q,K,V)=Concat(head1,,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)

# 模型初始化
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
'''
embed_dim – 嵌入向量总长度.

num_heads – 并行的head数目,即同时做多少次不同语义的attention.

dropout – dropout的概率.

bias – 是否添加偏置.默认: True.
'''


# 模型运算
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
'''
Inputs:

query: (L, N, E) where L is the target sequence length, N is the batch size, E is the embedding dimension.

key: (S, N, E) , where S is the source sequence length, N is the batch size, E is the embedding dimension.

value: (S, N, E) where S is the source sequence length, N is the batch size, E is the embedding dimension.

key_padding_mask: (N, S)(N,S) , ByteTensor, where N is the batch size, S is the source sequence length.

attn_mask: 2D mask (L, S)(L,S) where L is the target sequence length, S is the source sequence length. 3D mask (N*num_heads, L, S)(N∗num_heads,L,S) where N is the batch size, L is the target sequence length, S is the source sequence length.

Outputs:

attn_output: (L, N, E)(L,N,E) where L is the target sequence length, N is the batch size, E is the embedding dimension.

attn_output_weights: (N, L, S)(N,L,S) where N is the batch size, L is the target sequence length, S is the source sequence length.
'''

值得注意的一点是,query,key,value的输入形状一定是 [sequence_size, batch_size, emb_size]

官网例子:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

embed_dim, num_heads参数

但是我执行程序却报错了:

A = torch.arange(1,25).view(4,3,2)
A = A.float()

self_attn = torch.nn.MultiheadAttention(embed_dim=2, num_heads=4, dropout=0.0)
res,weight = self_attn(A,A,A)

报错信息:

    self_attn = torch.nn.MultiheadAttention(embed_dim=2, num_heads=4, dropout=0.0)
  File "E:\Anaconda3\lib\site-packages\torch\nn\modules\activation.py", line 740, in __init__
    assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
AssertionError: embed_dim must be divisible by num_heads

切到源码看不懂,而且我用pycharm 一直ctrl+鼠标进入不了最底层的代码,只有前几层的代码。(求相关领域大佬教教我)

经过自己的尝试,nn.MultiheadAttention(embed_dim, num_heads)中的要满足两点约束:

  • embed_dim == input_dim ,即query,key,value的embedding_size必须等于embed_dim
  • embed_dim%num_heads==0

上面的约束,也就说明了在使用nn.MultiheadAttention(embed_dim, num_heads)时, num_heads不是我们想设多少就设定多少。

我的看法:

nn.MultiheadAttention(embed_dim, num_heads) 中的embed_dim 是输入的embeddingsize,即query输入形状(L, N, E)的E数值,nn.MultiheadAttention 想要实现的是无论head 数目设置成多少,输出的向量大小都是不变的。写完这句话,发现自己表达能力真是弱,自己都看不懂。可以结合下面公式理解。
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O w h e r e   h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) Q ∈ R ( L , N , E ) ,   K ∈ R ( S , N , E ) ,   V ∈ R ( S , N , E ) ,   W i ∈ R ( E , E / h ) , W O ∈ R ( E , E ) MultiHead(Q,K,V)=Concat(head_1,…,head_h)W_O\quad where\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)\\ Q\in R^{(L, N, E)},\ K\in R^{(S, N, E)},\ V\in R^{(S, N, E)},\ W_i\in R^{(E,E/h)},W_O\in R^{(E,E)} MultiHead(Q,K,V)=Concat(head1,,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)QR(L,N,E), KR(S,N,E), VR(S,N,E), WiR(E,E/h),WOR(E,E)

attn_mask参数

self-attention公式: SA ⁡ ( Q , K , V ) = s o f t m a x ( Q K T d k ) V \operatorname{SA}(Q, K, V)=softmax\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V SA(Q,K,V)=softmax(dk QKT)V

Q K T Q K^{T} QKT生成权重分布,但是在应用中有些位置的权重是不可见的,比如在时间序列中,第t天时,我们并不知道t+1天之后的信息。这时就需要传入attn_mask参数,屏蔽这些不合理的权重。attn_mask要求是booltensor,某个位置true表示掩盖该位置。

import torch
import torch.nn as nn

A = torch.Tensor(5,2,4)
nn.init.xavier_normal_(A)
print(A)

# tensor([[[ 0.3688,  0.0391,  0.2048, -0.0906],
#          [-0.0654,  0.1193, -0.1792,  0.0470]],
#
#         [[ 0.0812, -0.4180, -0.1353, -0.2670],
#          [ 0.0433,  0.1442,  0.1733,  0.0535]],
#
#         [[ 0.2352, -0.3314, -0.0238,  0.4116],
#          [ 0.1062,  0.5122,  0.1572, -0.2991]],
#
#         [[ 0.3381,  0.4004, -0.1936, -0.1553],
#          [-0.0168,  0.5914,  0.7389, -0.1740]],
#
#         [[ 0.0446, -0.1739, -0.2020,  0.2580],
#          [-0.0109,  0.0854,  0.2634, -0.4735]]])

M = nn.MultiheadAttention(embed_dim=4, num_heads=2)
attention_mask = ~torch.tril(torch.ones([A.shape[0],A.shape[0]])).bool()
print(attention_mask)

# tensor([[False,  True,  True,  True,  True],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])

attn_output, attn_output_weights=M(A,A,A, attn_mask=attention_mask)
print(attention_mask)

# tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#          [0.5067, 0.4933, 0.0000, 0.0000, 0.0000],
#          [0.3350, 0.3276, 0.3374, 0.0000, 0.0000],
#          [0.2523, 0.2511, 0.2549, 0.2417, 0.0000],
#          [0.2004, 0.1962, 0.2039, 0.1981, 0.2013]],
# 
#         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#          [0.5025, 0.4975, 0.0000, 0.0000, 0.0000],
#          [0.3325, 0.3312, 0.3363, 0.0000, 0.0000],
#          [0.2535, 0.2429, 0.2633, 0.2404, 0.0000],
#          [0.2002, 0.1986, 0.2008, 0.1976, 0.2028]]], grad_fn=<DivBackward0>)
### nn.Transformer 模型参数量计算 对于 `nn.Transformer` 模型而言,其主要组成部分包括编码器(Encoder)、解码器(Decoder),以及嵌入层和线性变换层。这些组件共同决定了整个模型的参数数量。 #### 编码器部分 编码器由多个相同的层堆叠而成,每一层包含两个子层:一个多头自注意力机制(Multi-head Self-Attention Mechanism)和一个全连接前馈网络(Feedforward Network)。假设编码器有 $N_e=6$ 层,则每层中的多头自注意力机制会引入如下参数: - 查询矩阵权重 $\text{W}_Q \in R^{d_{model} \times d_k}$, 键矩阵权重 $\text{W}_K \in R^{d_{model} \times d_k}$ 和 值矩阵权重 $\text{W}_V \in R^{d_{model} \times d_v}$ 的总参数数为 $(d_{model}\times d_k+d_{model}\times d_k+d_{model}\times d_v)\times h$, 其中$h=8$ 是头部数目; - 输出投影矩阵 $\text{W}_{O} \in R^{hd_v\times d_{model}}$ 因此,单个多头自注意力模块的参数总数大约为: \[ (3 \times d_{model} \times d_k + hd_v \times d_{model})\times N_e \] 接着是全连接前馈网络 FFN 中有两个线性转换层加上激活函数 ReLU,第一个线性层输入维度$d_{model}=512$ ,输出维度为中间维度 $dim\_feedforward=2048$;第二个线性层则相反。所以FFN贡献了额外约 \[ 2(d_{model} \times dim\_feedforward) \times N_e \] 个可训练参数给编码器。 #### 解码器部分 解码器结构类似于编码器,但在每个解码单元之间还加入了掩蔽多头注意机制 Masked Multihead Attention 及交叉注意力 Cross-attention 来处理目标序列内部依赖性和源-目之间的交互作用。这部分同样具有类似的参数规模,即约为上述编码器参数的一倍左右[^1]。 #### 嵌入层与最终分类器 除了核心架构外,在实际应用时还需要考虑词表大小带来的影响。具体来说,源语言侧需设置大小为 src_vocab_size 的嵌入向量组,而目标端则是 tgt_vocab_size 维度的空间。这两个操作分别增加了 vocab size * embedding dimension 即 \(src\_vocab\_size \times d_{model}\),\(tgt\_vocab\_size \times d_{model}\) 数目的参数。最后通过一个线性层将 decoder output 映射回词汇空间完成预测任务,这又带来了另一个 \(d_{model} \times tgt\_vocab\_size\) 大小的参数集。 综上所述,可以得出 PyTorch 实现的标准 Transformer 架构下的近似参数总量公式为: \[ P = 2[(3 \times d_{model} \times d_k + hd_v \times d_{model})+(2 \times d_{model} \times dim\_feedforward)] \times layers + (src\_vocab\_size+tgt\_vocab\_size+1) \times d_{model} + d_{model} \times tgt\_vocab\_size \] 代入已知数值后得到的结果就是该特定配置下 Transformer 所拥有的大致参数量[^3]。 ```python import math def calculate_parameters(src_vocab_size, tgt_vocab_size, max_len, d_model, num_heads, num_encoder_layers, num_decoder_layers, dim_feedforward): params = 0 # Encoder and Decoder parameters calculation attention_params_per_layer = (3*d_model*(d_model//num_heads)+d_model*num_heads)*(num_encoder_layers+num_decoder_layers) ffn_params_per_layer = 2*d_model*dim_feedforward*(num_encoder_layers+num_decoder_layers) encoder_decoder_shared_params = attention_params_per_layer + ffn_params_per_layer # Embedding layer & final classifier parameters embeddings_and_classifier = (src_vocab_size + tgt_vocab_size)*d_model + d_model*tgt_vocab_size total_params = encoder_decoder_shared_params + embeddings_and_classifier return int(total_params) params_count = calculate_parameters(11000, 12000, 60, 512, 8, 6, 6, 2048) print(f'Total Parameters Count: {params_count}') ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值