手撕multi-head self attention 代码

在深度学习和自然语言处理领域,多头自注意力(Multi-Head Self-Attention)机制是Transformer模型中的核心组件之一。它允许模型在处理序列数据时,能够同时关注序列中的不同位置,从而捕获到丰富的上下文信息。下面,我们将详细解析多头自注意力机制的实现代码。

一、概述

多头自注意力机制的核心思想是将输入序列进行多次线性变换,然后分别计算自注意力得分,最后将所有头的输出进行拼接,并通过一个线性层得到最终的输出。这样做的好处是可以让模型从不同的子空间学习到不同的注意力信息,提高模型的表达能力。

二、代码实现

以下是一个简化版的多头自注意力机制的PyTorch实现,如果有不足之处,感谢指出!!!!:

import torch
import torch.nn as nn
import math

class MultiHeadSelfAttention(nn.Module):
    """
    多头注意力模块,用于实现transformer模型中的注意力机制。
    
    参数:
        model_dim: 模型维度,即输入和输出的向量维度。
        num_heads: 注意力头的数量。
        dropout_rate: Dropout率,防止模型过拟合,默认为0.1。
    """
    def __init__(self, model_dim, num_heads, dropout_rate=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        assert model_dim % num_heads == 0, "model_dim 必须能整除注意力头的数量。"
      
### 实现 Multi-head Attention 机制 #### 多头注意力机制概述 多头注意力(Multi-head Attention)允许模型在不同表示子空间中并行地关注输入的不同部分。这种机制显著增强了模型捕捉复杂模式的能力[^1]。 #### 关键组件解析 为了构建一个多头注意力模块,需要定义三个主要矩阵:Query(Q)、Key(K)和Value(V)。这些向量通过线性变换得到,并作为后续计算的基础。具体来说: - **Query (Q)** 表示查询向量; - **Key (K)** 对应于记忆库中的条目; - **Value (V)** 则保存了与Keys关联的信息内容; 对于每一个头部,上述三者都会被映射到较低维度的空间内以便处理。之后再将各个头部的结果拼接起来并通过另一个投影层输出最终结果。 #### Python代码实现 下面是一个简单的基于PyTorch框架的多头自注意力建模实例: ```python import torch import torch.nn as nn class ScaledDotProductAttention(nn.Module): """Scaled dot-product attention mechanism.""" def __init__(self, d_k): super(ScaledDotProductAttention, self).__init__() self.d_k = d_k def forward(self, Q, K, V, attn_mask=None): scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.d_k ** 0.5) if attn_mask is not None: scores.masked_fill_(attn_mask, -1e9) attentions = nn.Softmax(dim=-1)(scores) context = torch.matmul(attentions, V) return context, attentions class MultiHeadAttention(nn.Module): def __init__(self, d_model=512, num_heads=8): super(MultiHeadAttention, self).__init__() assert d_model % num_heads == 0 self.num_heads = num_heads self.d_k = d_model // num_heads self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.linear = nn.Linear(d_model, d_model) self.layer_norm = nn.LayerNorm(d_model) def split_heads(self, x, batch_size): # Split the last dimension into (num_heads * depth). x = x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) return x def combine_heads(self, x): batch_size = x.size()[0] x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k) return x def forward(self, Q, K, V, mask=None): residual = Q.clone() batch_size = Q.size(0) Q = self.split_heads(self.W_Q(Q), batch_size) K = self.split_heads(self.W_K(K), batch_size) V = self.split_heads(self.W_V(V), batch_size) context, _ = ScaledDotProductAttention(self.d_k)(Q, K, V, mask) output = self.combine_heads(context) result = self.linear(output) return self.layer_norm(result + residual) if __name__ == "__main__": mha = MultiHeadAttention() query = key = value = torch.rand((64, 32, 512)) out = mha(query, key, value) print(out.shape) ``` 此段代码实现了标准的多头自我注意机制,其中包含了缩放点乘法注意力函数以及必要的预处理步骤如分割和重组heads的操作。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

心若成风、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值