多头注意力机制 python代码实现

Python3.9

Python3.9

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

上一章我们讲了多头注意力机制的原理,并且使用了“我请李雷帮我订了去北京的机票“这句话进行了模拟推演。本章配上具体的python代码实现,加深理解。
 


import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# ====================== 1. 准备数据 ======================

# 设置中文字体(根据系统安装的字体选择,以下是常见兼容字体)
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]  # 支持中文显示
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示异常的问题(可选)


tokens = ["我", "请", "李雷", "帮", "我", "订了", "去", "北京", "的", "机票"]
N = len(tokens)  # token 数量
embedding_dim = 8
num_heads = 4
head_dim = embedding_dim // num_heads


# 模拟 embedding 层输出 (N, D) 10 8
embeddings = torch.randn(N, embedding_dim) * 0.5



# ====================== 2. 定义线性层参数 ======================
# 8*8
W_q = torch.randn(embedding_dim, embedding_dim) * 0.3
W_k = torch.randn(embedding_dim, embedding_dim) * 0.3
W_v = torch.randn(embedding_dim, embedding_dim) * 0.3
W_o = torch.randn(embedding_dim, embedding_dim) * 0.3


# ====================== 3. 计算 Q / K / V ======================
Q = embeddings @ W_q  # (N, D)
# print(Q)
K = embeddings @ W_k
V = embeddings @ W_v
# print(Q.size(), K.size(), V.size())
# print(Q)

# 拆分多头
def split_heads(x, num_heads):
    """把 (N, D) → (num_heads, N, D/num_heads)"""
    N, D = x.shape
    #  # 第二步:交换维度 0 和 1,得到 (num_heads, N, D//num_heads)
    return x.view(N, num_heads, D // num_heads).transpose(0, 1)

# 10*8 10个token 每个被编码为8维向量 多头注意力 4头 每个8维向量被拆成4个2维向量
Q_heads = split_heads(Q, num_heads)
# torch.Size([4, 10, 2])
# print(Q_heads.shape)
K_heads = split_heads(K, num_heads)
V_heads = split_heads(V, num_heads)



# ====================== 4. 计算每个 Head 的注意力 ======================
attention_outputs = []
attention_weights = []

scale = head_dim ** 0.5

# exit()
for h in range(num_heads):
    q, k, v = Q_heads[h], K_heads[h], V_heads[h]  # (N, head_dim)

    # QK^T → (N, N) 自注意力(Self-Attention)矩阵。
    scores = (q @ k.T) / scale
    # print(scores.shape)
    # print(scores)

    # Softmax
    weights = F.softmax(scores, dim=-1)
    attention_weights.append(weights)

    # 加权求和
    out = weights @ v  # (N, head_dim)
    attention_outputs.append(out)


    print(f"\n--- Head {h+1} ---")
    print("注意力权重 (softmax 后):\n", torch.round(weights, decimals=3))


# ====================== 5. 拼接多头输出 ======================
concat = torch.cat(attention_outputs, dim=-1)  # (N, D)
final_output = concat @ W_o                    # (N, D)

print("\n=== 最终输出 ===\n", torch.round(final_output, decimals=3))


# exit()
# ====================== 6. 可视化注意力 ======================
fig, axes = plt.subplots(1, num_heads, figsize=(4 * num_heads, 3))
for h in range(num_heads):
    ax = axes[h]
    im = ax.imshow(attention_weights[h].detach().numpy(), cmap='Blues', aspect='auto')
    ax.set_title(f"Head {h+1}")
    ax.set_xticks(range(N))
    ax.set_yticks(range(N))
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_yticklabels(tokens)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()

# ====================== 7. 打印每个 token 的 Top-3 注意力 ======================
print("\n=== 每个 token 在各 head 的 Top-3 注意力对象 ===")
for h in range(num_heads):
    weights = attention_weights[h]
    print(f"\n-- Head {h+1} --")
    for i in range(N):
        top3 = torch.topk(weights[i], 3)
        idxs = top3.indices.tolist()
        vals = top3.values.tolist()
        print(f"{tokens[i]} → {[f'{tokens[idx]}({val:.3f})' for idx, val in zip(idxs, vals)]}")

以上代码会输出:

  • 每个头(head)的注意力权重矩阵(数值 + 可视化 heatmap)

  • 每个 token 的前 3 个最关注对象

  • 每个 token 的最终向量输出(融合了上下文语义)


=== 每个 token 在各 head 的 Top-3 注意力对象 ===

从图中我们可以非常直观的看出来每一个token和其他token之间的关注度权重,能够明白模型是如何理解语义信息的,也可以从结果中看出来,头的侧重点不一样,token之间关注的权重会有一些细微的变化,但句子、段落表达出的总体的意思都是大差不差的。都是--我请李雷帮我订了去北京的机票
希望结合代码能够帮助你更好的理解多头注意力机制。

您可能感兴趣的与本文相关的镜像

Python3.9

Python3.9

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

<think>好的,我现在需要帮助用户找到关于注意力机制的Python代码实现。用户已经提供了几个引用,里面包含不同的代码示例和相关描述。首先,我需要理解用户的具体需求,他们可能想了解如何从零开始实现注意力机制,或者如何在神经网络中应用它,比如在自然语言处理或图神经网络中。 首先,我应该回顾用户提供的引用内容。引用[1]提到了自注意力机制的Python代码,引用[2]是关于多头注意力机制的手撕代码,引用[3]展示了一个包含注意力层的模型,引用[4]则涉及图注意力网络(GAT)。用户的问题是如何用Python实现注意力机制,所以需要涵盖这些不同的方面,但可能更关注基础实现。 接下来,我需要确保代码示例符合用户的要求,并且正确使用LaTeX格式。比如,数学公式需要用$...$或$$...$$包裹。例如,注意力权重计算中的Softmax部分,可能需要写成$\text{Softmax}(\frac{QK^T}{\sqrt{d_k}})$。 然后,我需要组织回答结构,先介绍基础注意力机制,再逐步扩展到多头注意力和应用示例。每个部分都应有代码片段和解释,同时引用提供的参考文献。例如,在基础注意力机制部分,可以引用[1]和[3],而多头部分引用[2],GAT部分引用[4]。 还要注意用户提到的系统级指令,比如行内数学表达式使用$...$,独立公式单独成段用$$...$$。例如,注意力得分的计算公式可能需要独立展示。 此外,用户要求生成相关问题,我需要根据内容提出相关的问题,比如比较不同注意力机制的区别,或者讨论性能优化的方法。这些问题需要与用户提供的引用内容相关,比如引用[4]提到的并行计算和稀疏性,可以生成相关问题。 最后,检查所有引用的正确性,确保每个引用标识对应正确的来源,并且代码示例没有语法错误。例如,引用[2]中的代码片段是否有缩进问题,或者变量名称是否正确。同时,确保使用中文回答,避免使用专业术语时没有解释,保持回答清晰易懂。</think>### 注意力机制的Python实现分步解析 #### 一、基础注意力机制实现 注意力机制的核心是通过计算查询向量$Q$和键向量$K$的相似度,加权值向量$V$得到结果: $$ \text{Attention}(Q,K,V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ ```python import torch import torch.nn.functional as F def attention(Q, K, V, mask=None): d_k = Q.size(-1) scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k)) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_weights = F.softmax(scores, dim=-1) return torch.matmul(attn_weights, V), attn_weights ``` 此代码实现了缩放点积注意力,包含防止梯度爆炸的$\sqrt{d_k}$缩放和可选的掩码机制[^1]。 #### 二、多头注意力实现 多头注意力将输入拆分为多个子空间并行计算,增强模型表达能力[^2]: ```python class MultiHeadAttention(torch.nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.Wq = torch.nn.Linear(embed_dim, embed_dim) self.Wk = torch.nn.Linear(embed_dim, embed_dim) self.Wv = torch.nn.Linear(embed_dim, embed_dim) self.out = torch.nn.Linear(embed_dim, embed_dim) def split_heads(self, x): batch_size = x.size(0) return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) def forward(self, x): Q = self.split_heads(self.Wq(x)) K = self.split_heads(self.Wk(x)) V = self.split_heads(self.Wv(x)) attn_output, _ = attention(Q, K, V) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(x.size(0), -1, self.num_heads * self.head_dim) return self.out(attn_output) ``` #### 三、应用示例(Seq2Seq模型) ```python class AttentionDecoder(torch.nn.Module): def __init__(self, output_dim, embed_dim, hidden_dim): super().__init__() self.attention = MultiHeadAttention(embed_dim, 8) self.lstm = torch.nn.LSTM(embed_dim*2, hidden_dim) self.fc_out = torch.nn.Linear(hidden_dim, output_dim) def forward(self, encoder_outputs, hidden): query = hidden[0].permute(1, 0, 2) context, _ = self.attention(query, encoder_outputs, encoder_outputs) lstm_input = torch.cat([query, context], dim=2) output, hidden = self.lstm(lstm_input, hidden) prediction = self.fc_out(output.squeeze(1)) return prediction, hidden ``` 该结构展示了注意力机制在序列到序列模型中的应用,通过关注编码器输出的相关部分增强解码效果[^3]。 #### 四、图注意力网络(GAT)实现 ```python class GATLayer(torch.nn.Module): def __init__(self, in_features, out_features, alpha=0.2): super().__init__() self.W = torch.nn.Linear(in_features, out_features) self.a = torch.nn.Linear(2*out_features, 1) self.leakyrelu = torch.nn.LeakyReLU(alpha) def edge_attention(self, edges): h = torch.cat([edges.src['h'], edges.dst['h']], dim=1) e = self.leakyrelu(self.a(h)) return {'e': e} def forward(self, g, h): with g.local_scope(): g.ndata['h'] = self.W(h) g.apply_edges(self.edge_attention) g.edata['a'] = torch.softmax(g.edata['e'], dim=1) g.update_all(fn.u_mul_e('h', 'a', 'm'), fn.sum('m', 'h_new')) return g.ndata['h_new'] ``` 该实现展示了图注意力机制如何聚合邻居节点信息,通过可学习的注意力权重实现自适应特征融合[^4]。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值