LLM数学推导——Transformer问题集——注意力机制——基础组件

Q1:推导标准点积注意力公式 \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V 的梯度计算

标准点积注意力公式是啥?

标准点积注意力公式 \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 是 Transformer 架构的核心组件,用于计算输入序列中元素的语义关联权重。

  • 核心作用:通过查询矩阵 Q、键矩阵 K、值矩阵 V 的交互,为每个查询动态分配对键值对的注意力权重,实现对输入信息的加权聚合。例如,在语言模型中,可让模型聚焦于当前词的上下文关键词。
  • 关键创新:引入 \sqrt{d_k} 缩放点积结果,避免高维空间中数值过大导致 Softmax 梯度消失,提升训练稳定性。
梯度推导过程

设 A = \frac{QK^T}{\sqrt{d_k}}\Omega = \text{Softmax}(A),输出 O = \Omega V。梯度推导需解决三个核心问题:

  1. 对 V 的梯度: 由矩阵乘法求导规则直接得 \frac{\partial O}{\partial V} = \Omega^T,体现 “值矩阵加权” 的线性特性。

  2. 对 Q 的梯度

    • 通过链式法则分解为 O \leftarrow \Omega \leftarrow A \leftarrow Q
    • A 对 Q 的导数为 \frac{K_j^T}{\sqrt{d_k}},结合 Softmax 导数性质 \frac{\partial \Omega_{ij}}{\partial A_{pq}} = \Omega_{ij}(\delta_{ip} - \Omega_{pq}),最终化简为:\frac{\partial O}{\partial Q} = \frac{1}{\sqrt{d_k}} (\Omega - \Omega \Omega^T) V K其中 (\Omega - \Omega \Omega^T)反映权重分布与均值的差异,调节梯度方向。
  3. 对 K 的梯度: 推导逻辑与 Q 对称,结果为:\frac{\partial O}{\partial K} = \frac{1}{\sqrt{d_k}} (\Omega - \Omega \Omega^T) V^T Q 差异源于 K 在点积中的转置位置,梯度包含 V^T Q 而非 V K。

在 LLM 中的梯度应用

在大语言模型训练中,梯度计算是反向传播优化的核心,直接影响参数更新效率:

  • 参数更新机制: 梯度 \frac{\partial O}{\partial Q/K/V} 指示投影矩阵的调整方向。例如,若某查询 - 键对的梯度较大,说明其关联权重需重点优化,模型会增强该语义关联的学习。
  • 训练稳定性优化
    • 缩放因子 \sqrt{d_k}通过控制梯度幅度避免消失,尤其在高维特征(如 d_k=1024)时效果显著。
    • 权重差异矩阵 (\Omega - \Omega \Omega^T) 具有 “抗过拟合” 作用:当注意力集中于少数位置时,梯度幅度自动减小,防止模型过度依赖局部信息。
代码示例与解析
import torch  

# 定义可学习的投影矩阵  
Q = torch.randn(8, 64, requires_grad=True)  # 8个查询,每个维度64  
K = torch.randn(8, 64, requires_grad=True)  # 8个键,维度与Q一致  
V = torch.randn(8, 128, requires_grad=True)  # 8个值,维度128  
d_k = Q.size(1)  # d_k=64  

# 前向传播:计算注意力输出  
A = torch.matmul(Q, K.T) / torch.sqrt(torch.tensor(d_k))  # 缩放点积:(8,8)  
Ω = torch.softmax(A, dim=1)  # 注意力权重:(8,8)  
O = torch.matmul(Ω, V)  # 输出:(8,128)  

# 反向传播:模拟损失函数梯度(此处用随机目标)  
loss = (O - torch.randn_like(O)).mean()  
loss.backward()  

# 打印梯度范数(反映更新幅度)  
print("Q梯度范数:", torch.norm(Q.grad))  
print("K梯度范数:", torch.norm(K.grad))  
print("V梯度范数:", torch.norm(V.grad))  

代码解析

  • 可导投影层Q/K/V通常由模型参数生成(如线性层),此处用随机张量模拟可学习参数。
  • 缩放点积核心逻辑:通过 K.T 转置实现查询 - 键的成对匹配,除以 \sqrt{d_k} 确保数值稳定。
  • 梯度观察:打印梯度范数可监控参数更新强度,若某矩阵范数异常增大,需调整学习率或启用正则化。
总结

标准点积注意力的梯度推导揭示了其优化本质:通过 Softmax 的 “竞争机制”(\Omega - \Omega \Omega^T)和缩放操作,平衡了注意力分布的集中性与梯度有效性。在 LLM 中,这一机制使模型能动态调整语义关联强度,同时通过反向传播高效优化参数。代码实现则借助 PyTorch 的自动微分,将复杂的矩阵导数计算抽象为简洁的 API 调用,体现了理论与工程的深度结合。


Q2:证明缩放因子 \sqrt{d_k} 对梯度方差稳定性的作用(假设 Q, K \sim N(0, 1)

缩放因子 \sqrt{d_k} 是啥?

在标准点积注意力机制里,会用到公式 \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,这里的 \sqrt{d_k} 就是缩放因子。其中 Q 是查询矩阵,K 是键矩阵,d_k 是 Q 和 K 的维度。这个缩放因子的出现,主要是为了解决高维情况下点积结果数值过大的问题。要是没有它,后续的计算就容易出状况,比如梯度不稳定,这会影响模型训练的效果和效率。

证明过程
1. 先看未缩放时 QK^T的元素分布

假设 Q 和 K 的元素都服从标准正态分布 N(0, 1)。设 Q 是 m \times d_k的矩阵,K 是 n \times d_k 的矩阵,那么 QK^T 中元素 (QK^T)_{ij} = \sum_{l = 1}^{d_k} Q_{il}K_{jl}。 根据正态分布的性质,如果 X_1, X_2, \cdots, X_n相互独立且都服从 N(0, 1),那么 \sum_{i = 1}^{n}X_i服从 N(0, n)。在这里,因为 Q_{il} 和 K_{jl} 都服从 N(0, 1) 且相互独立,所以 (QK^T)_{ij} 服从 N(0, d_k)。这意味着,随着 d_k 的增大,QK^T 元素的方差也会增大。

2. 再看缩放后的 \frac{QK^T}{\sqrt{d_k}}的元素分布

当我们用 \sqrt{d_k} 对 QK^T 进行缩放时,得到 \frac{(QK^T)_{ij}}{\sqrt{d_k}}。根据正态分布的性质,若 X \sim N(0, \sigma^2),那么 \frac{X}{c} \sim N(0, \frac{\sigma^2}{c^2})(c 为常数)。所以 \frac{(QK^T)_{ij}}{\sqrt{d_k}} 服从 N(0, 1)。也就是说,不管 d_k 怎么变,缩放后元素的方差始终稳定在 1。

3. 接着分析 Softmax 函数的输入对梯度的影响

Softmax 函数用于将输入转换为概率分布,公式为 \text{Softmax}(x)_i = \frac{e^{x_i}}{\sum_{j = 1}^{n}e^{x_j}}。当输入的方差较大时,Softmax 函数的输出会趋近于 one - hot 向量。这是因为方差大意味着输入值之间的差异大,经过指数运算后,大的值会变得更大,小的值会变得更小。 对 Softmax 函数求导,其导数与输入值和输出值都有关系。当 Softmax 输出趋近于 one - hot 向量时,梯度会变得非常小,出现梯度消失的问题。而当我们使用 \sqrt{d_k} 缩放输入后,输入的方差稳定,Softmax 函数的输出分布更加合理,梯度也就更加稳定。

4. 最后看梯度方差的变化

在反向传播过程中,梯度的计算与 Softmax 函数的输出有关。未缩放时,由于 QK^T元素方差随 d_k 增大而增大,Softmax 输出不稳定,导致梯度方差也不稳定。而缩放后,\frac{QK^T}{\sqrt{d_k}} 元素方差稳定,Softmax 输出稳定,梯度方差也随之稳定。

在 LLM 中的作用

在大语言模型(LLM)里,缩放因子 \sqrt{d_k} 对梯度方差稳定性的作用至关重要。

  • 提升训练稳定性:在训练过程中,如果梯度方差不稳定,可能会出现梯度消失或梯度爆炸的问题。梯度消失会使模型参数更新缓慢,甚至停滞不前;梯度爆炸则会使参数更新过大,导致模型无法收敛。使用 \sqrt{d_k} 缩放后,梯度方差稳定,能有效避免这些问题,让模型训练更加稳定。
  • 加速收敛:稳定的梯度能让模型在每次迭代中更合理地更新参数,避免了因梯度异常而进行的无效更新。这样,模型可以更快地收敛到最优解,提高训练效率。
代码示例
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# 定义维度范围
d_k_values = [16, 32, 64, 128, 256, 512, 1024]
num_trials = 100

# 存储未缩放和缩放后的梯度方差
unscaled_variances = []
scaled_variances = []

for d_k in d_k_values:
    unscaled_grads = []
    scaled_grads = []

    for _ in range(num_trials):
        # 生成 Q 和 K 矩阵
        Q = torch.randn(1, d_k)
        K = torch.randn(1, d_k)

        # 未缩放的点积
        unscaled_dot_product = torch.matmul(Q, K.T)
        unscaled_softmax = F.softmax(unscaled_dot_product, dim=1)
        # 模拟损失函数,这里简单用求和
        unscaled_loss = unscaled_softmax.sum()
        unscaled_loss.backward(retain_graph=True)
        unscaled_grads.append(Q.grad.view(-1).clone())
        Q.grad.zero_()

        # 缩放后的点积
        scaled_dot_product = torch.matmul(Q, K.T) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        scaled_softmax = F.softmax(scaled_dot_product, dim=1)
        scaled_loss = scaled_softmax.sum()
        scaled_loss.backward()
        scaled_grads.append(Q.grad.view(-1).clone())
        Q.grad.zero_()

    # 计算梯度方差
    unscaled_variances.append(torch.var(torch.stack(unscaled_grads)))
    scaled_variances.append(torch.var(torch.stack(scaled_grads)))

# 绘制方差变化图
plt.plot(d_k_values, unscaled_variances, label='Unscaled')
plt.plot(d_k_values, scaled_variances, label='Scaled')
plt.xlabel('$d_k$')
plt.ylabel('Gradient Variance')
plt.title('Effect of Scaling Factor on Gradient Variance')
plt.legend()
plt.show()
代码解释
  • 数据生成:代码中通过 torch.randn 函数生成服从标准正态分布的 Q 和 K 矩阵。
  • 未缩放和缩放计算:分别计算未缩放的QK^T 和缩放后的 \frac{QK^T}{\sqrt{d_k}},然后对它们应用 Softmax 函数,模拟损失函数并进行反向传播计算梯度。
  • 方差计算与绘图:多次试验后,计算未缩放和缩放情况下梯度的方差,并使用 matplotlib 绘制方差随 d_k变化的曲线。从图中可以直观地看到,未缩放时梯度方差随 d_k 增大而增大,而缩放后梯度方差保持相对稳定。
总结

缩放因子 \sqrt{d_k} 通过将 QK^T 元素的方差稳定在 1,避免了 Softmax 函数输出的极端情况,从而保证了梯度方差的稳定性。在大语言模型训练中,这一稳定性对于避免梯度消失和爆炸、提高训练效率和模型性能具有重要意义。代码示例则直观地展示了缩放因子对梯度方差稳定性的作用。


Q3:多头注意力的并行计算复杂度分析(FLOPs 公式推导)

多头注意力是啥?

多头注意力机制是 Transformer 架构中的关键组件,它允许模型在不同的表示子空间中并行地关注输入序列的不同部分。标准的注意力机制可以表示为 \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,而多头注意力则是将 Q、K、V 分别线性投影到 h 个低维子空间(头),并行计算每个头的注意力,最后将结果拼接并再次进行线性投影。

多头注意力的主要作用是增强模型对不同特征和模式的捕捉能力。通过多个头并行工作,模型可以从不同的角度关注输入序列,从而提取更丰富的语义信息。

并行计算复杂度分析及 FLOPs 公式推导
1. 单个头的计算复杂度

首先,我们来分析单个头的注意力计算复杂度。假设输入序列的长度为 n,查询、键和值矩阵的维度分别为 d_qd_k 和 d_v,在多头注意力中,每个头的维度通常为 d_{head}=\frac{d_k}{h}=\frac{d_v}{h}(这里 h 是头的数量)。

  • 计算 QK^T 计算 QK^T 是一个矩阵乘法操作,Q 是 n \times d_{head} 的矩阵,K^T 是 d_{head} \times n 的矩阵。矩阵乘法的计算复杂度为 O(n\times d_{head}\times n) = O(n^2d_{head})。具体来说,对于 QK^T 中的每个元素,需要进行 d_{head} 次乘法和 d_{head}-1 次加法,总共 n\times n 个元素,所以乘法和加法的操作次数大约为 n^2d_{head} 次。

  • 缩放操作: 对 QK^T 进行缩放,即除以 \sqrt{d_{head}},这只需要 n^2 次除法操作,相对于矩阵乘法来说复杂度较低,在大 n 和 d_{head} 的情况下可以忽略不计。

  • Softmax 操作: Softmax 操作是对矩阵的每一行进行归一化。对于每一行(共 n 行),需要进行 n 次指数运算、n - 1 次加法和 n 次除法。指数运算的复杂度相对较高,但在实际计算中,通常可以将其视为常数时间操作。所以 Softmax 操作的复杂度为 O(n^2),同样在大 n 和 d_{head}的情况下相对于矩阵乘法可忽略。

  • 计算\text{Softmax}(QK^T/\sqrt{d_{head}})V 这又是一个矩阵乘法操作,\text{Softmax}(QK^T/\sqrt{d_{head}}) 是 n \times n 的矩阵,V 是 n \times d_{head} 的矩阵,计算复杂度为 O(n\times n\times d_{head}) = O(n^2d_{head})

    综合来看,单个头的注意力计算复杂度主要由矩阵乘法决定,为 O(n^2d_{head})

2. h 个头的并行计算复杂度

由于 h 个头是并行计算的,所以 h 个头的计算复杂度仍然是 O(n^2d_{head}),因为并行计算可以同时处理多个头,而不是依次处理。

3. 线性投影的计算复杂度

在多头注意力中,还需要

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

墨顿

唵嘛呢叭咪吽

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值