多头注意力(Multi-Head Attention, MHA)

多头注意力(Multi-Head Attention, MHA)

多头注意力(Multi-Head Attention, MHA) 是 Transformer 模型(如 BERTGPT)中的核心机制,它扩展了缩放点积注意力(Scaled Dot-Product Attention),使模型可以从多个不同的角度关注输入序列的不同部分。这种机制提高了模型的表达能力,使其能够同时捕获不同的语义信息。


1. 为什么需要多头注意力?

在标准的缩放点积注意力(Scaled Dot-Product Attention)中,查询(Query)、键(Key)和值(Value)之间的关系是通过单一的注意力机制计算的。但这种单一的注意力机制存在一定的局限性:

  • 它只能关注输入序列中的一种信息模式(例如,长距离依赖)。
  • 在实际任务(如机器翻译、文本理解)中,不同的单词可能会以不同的方式关注上下文。

为了解决这个问题,多头注意力通过 多个不同的注意力头(heads) 来分别学习不同的注意力权重,并最终将它们组合在一起,从而增强模型的表示能力。


2. 多头注意力的计算过程

多头注意力的核心思想是:

  1. 对输入进行线性变换,将输入序列映射到多个低维查询(Q)、键(K)和值(V) 空间。
  2. 在多个注意力头(Heads)上计算注意力,每个头使用缩放点积注意力(Scaled Dot-Product Attention)。
  3. 将多个头的输出拼接(Concat)并投影回输出维度,以形成最终的多头注意力输出。

数学公式

假设输入张量的维度为 d model d_{\text{model}} dmodel,多头注意力的计算步骤如下:

  1. 输入线性变换
    Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV
    其中:

    • X X X 是输入序列(形状:(batch_size, seq_len, d_model))。
    • W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV 是不同的线性变换矩阵,用于生成查询、键和值(形状: ( d model , d head ) (d_{\text{model}}, d_{\text{head}}) (dmodel,dhead))。
    • d head = d model / h d_{\text{head}} = d_{\text{model}} / h dhead=dmodel/h(其中 h h h 是注意力头的数量)。
  2. 计算缩放点积注意力(Scaled Dot-Product Attention)
    Attention ( Q i , K i , V i ) = softmax ( Q i K i T d head ) V i \text{Attention}(Q_i, K_i, V_i) = \text{softmax} \left(\frac{Q_i K_i^T}{\sqrt{d_{\text{head}}}} \right) V_i Attention(Qi,Ki,Vi)=softmax(dhead QiKiT)Vi
    这个步骤与标准的缩放点积注意力相同,只是每个注意力头都独立计算自己的注意力分数。

  3. 拼接(Concat)多个头的输出
    MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat} (\text{head}_1, \text{head}_2, \dots, \text{head}_h) W^O MultiHead(Q,K,V)=Concat(head1,head2,,headh)WO
    其中:

    • Concat 操作将所有头的输出沿最后一个维度拼接(形状变回 (batch_size, seq_len, d_model))。
    • W O W^O WO 是投影矩阵,用于将拼接后的张量映射回原始维度。

3. PyTorch 代码实现

PyTorch 提供了 torch.nn.MultiheadAttention 模块,可以直接使用多头注意力。但我们先来看手动实现的多头注意力:

手动实现多头注意力

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除!"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads  # 每个头的维度
        
        # 线性变换矩阵
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        """
        计算缩放点积注意力
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_head, dtype=torch.float32))
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, V), attention_weights

    def forward(self, Q, K, V):
        batch_size = Q.shape[0]

        # 线性变换
        Q = self.W_q(Q)  # (batch_size, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)

        # 拆分成多个头 (batch_size, seq_len, num_heads, d_head)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

        # 计算注意力
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V)

        # 重新组合多头 (batch_size, seq_len, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 最终的线性变换
        return self.W_o(attention_output), attention_weights

# 测试模型
batch_size = 2
seq_len = 5
d_model = 16
num_heads = 4

multihead_attention = MultiHeadAttention(d_model, num_heads)
X = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = multihead_attention(X, X, X)
print("Output shape:", output.shape)
print("Attention Weights shape:", attn_weights.shape)

4. PyTorch 官方 torch.nn.MultiheadAttention

在实际项目中,可以使用 torch.nn.MultiheadAttention 来实现多头注意力:

import torch
import torch.nn as nn

# 定义多头注意力
d_model = 16
num_heads = 4
multihead_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

# 创建随机输入 (batch_size, seq_len, d_model)
X = torch.randn(2, 5, 16)

# 计算注意力
output, attn_weights = multihead_attn(X, X, X)

print("Output shape:", output.shape)
print("Attention Weights shape:", attn_weights.shape)

5. 总结

  • 多头注意力(Multi-Head Attention) 通过多个注意力头捕获不同的语义信息,使得模型可以同时关注输入序列的不同部分。
  • 计算过程
    1. 线性变换得到多个查询(Q)、键(K)、值(V)。
    2. 计算多个缩放点积注意力(Scaled Dot-Product Attention)。
    3. 拼接所有注意力头的输出,并通过一个线性变换映射回原始维度。
  • PyTorch 提供了官方实现,可以直接使用 torch.nn.MultiheadAttention 进行调用。

多头注意力是 Transformer 的核心部分,使其能够高效建模长序列数据,是当前 NLP 和 CV 任务的关键技术之一。

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值