多头注意力机制 python代码实现

上一章我们讲了多头注意力机制的原理,并且使用了“我请李雷帮我订了去北京的机票“这句话进行了模拟推演。本章配上具体的python代码实现,加深理解。
 


import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# ====================== 1. 准备数据 ======================

# 设置中文字体(根据系统安装的字体选择,以下是常见兼容字体)
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]  # 支持中文显示
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示异常的问题(可选)


tokens = ["我", "请", "李雷", "帮", "我", "订了", "去", "北京", "的", "机票"]
N = len(tokens)  # token 数量
embedding_dim = 8
num_heads = 4
head_dim = embedding_dim // num_heads


# 模拟 embedding 层输出 (N, D) 10 8
embeddings = torch.randn(N, embedding_dim) * 0.5



# ====================== 2. 定义线性层参数 ======================
# 8*8
W_q = torch.randn(embedding_dim, embedding_dim) * 0.3
W_k = torch.randn(embedding_dim, embedding_dim) * 0.3
W_v = torch.randn(embedding_dim, embedding_dim) * 0.3
W_o = torch.randn(embedding_dim, embedding_dim) * 0.3


# ====================== 3. 计算 Q / K / V ======================
Q = embeddings @ W_q  # (N, D)
# print(Q)
K = embeddings @ W_k
V = embeddings @ W_v
# print(Q.size(), K.size(), V.size())
# print(Q)

# 拆分多头
def split_heads(x, num_heads):
    """把 (N, D) → (num_heads, N, D/num_heads)"""
    N, D = x.shape
    #  # 第二步:交换维度 0 和 1,得到 (num_heads, N, D//num_heads)
    return x.view(N, num_heads, D // num_heads).transpose(0, 1)

# 10*8 10个token 每个被编码为8维向量 多头注意力 4头 每个8维向量被拆成4个2维向量
Q_heads = split_heads(Q, num_heads)
# torch.Size([4, 10, 2])
# print(Q_heads.shape)
K_heads = split_heads(K, num_heads)
V_heads = split_heads(V, num_heads)



# ====================== 4. 计算每个 Head 的注意力 ======================
attention_outputs = []
attention_weights = []

scale = head_dim ** 0.5

# exit()
for h in range(num_heads):
    q, k, v = Q_heads[h], K_heads[h], V_heads[h]  # (N, head_dim)

    # QK^T → (N, N) 自注意力(Self-Attention)矩阵。
    scores = (q @ k.T) / scale
    # print(scores.shape)
    # print(scores)

    # Softmax
    weights = F.softmax(scores, dim=-1)
    attention_weights.append(weights)

    # 加权求和
    out = weights @ v  # (N, head_dim)
    attention_outputs.append(out)


    print(f"\n--- Head {h+1} ---")
    print("注意力权重 (softmax 后):\n", torch.round(weights, decimals=3))


# ====================== 5. 拼接多头输出 ======================
concat = torch.cat(attention_outputs, dim=-1)  # (N, D)
final_output = concat @ W_o                    # (N, D)

print("\n=== 最终输出 ===\n", torch.round(final_output, decimals=3))


# exit()
# ====================== 6. 可视化注意力 ======================
fig, axes = plt.subplots(1, num_heads, figsize=(4 * num_heads, 3))
for h in range(num_heads):
    ax = axes[h]
    im = ax.imshow(attention_weights[h].detach().numpy(), cmap='Blues', aspect='auto')
    ax.set_title(f"Head {h+1}")
    ax.set_xticks(range(N))
    ax.set_yticks(range(N))
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_yticklabels(tokens)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()

# ====================== 7. 打印每个 token 的 Top-3 注意力 ======================
print("\n=== 每个 token 在各 head 的 Top-3 注意力对象 ===")
for h in range(num_heads):
    weights = attention_weights[h]
    print(f"\n-- Head {h+1} --")
    for i in range(N):
        top3 = torch.topk(weights[i], 3)
        idxs = top3.indices.tolist()
        vals = top3.values.tolist()
        print(f"{tokens[i]} → {[f'{tokens[idx]}({val:.3f})' for idx, val in zip(idxs, vals)]}")

以上代码会输出:

  • 每个头(head)的注意力权重矩阵(数值 + 可视化 heatmap)

  • 每个 token 的前 3 个最关注对象

  • 每个 token 的最终向量输出(融合了上下文语义)


=== 每个 token 在各 head 的 Top-3 注意力对象 ===

从图中我们可以非常直观的看出来每一个token和其他token之间的关注度权重,能够明白模型是如何理解语义信息的,也可以从结果中看出来,头的侧重点不一样,token之间关注的权重会有一些细微的变化,但句子、段落表达出的总体的意思都是大差不差的。都是--我请李雷帮我订了去北京的机票
希望结合代码能够帮助你更好的理解多头注意力机制。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值