Q1:推导标准点积注意力公式
的梯度计算
标准点积注意力公式是啥?
标准点积注意力公式 是 Transformer 架构的核心组件,用于计算输入序列中元素的语义关联权重。
- 核心作用:通过查询矩阵 Q、键矩阵 K、值矩阵 V 的交互,为每个查询动态分配对键值对的注意力权重,实现对输入信息的加权聚合。例如,在语言模型中,可让模型聚焦于当前词的上下文关键词。
- 关键创新:引入
缩放点积结果,避免高维空间中数值过大导致 Softmax 梯度消失,提升训练稳定性。
梯度推导过程
设 ,
,输出
。梯度推导需解决三个核心问题:
-
对 V 的梯度: 由矩阵乘法求导规则直接得
,体现 “值矩阵加权” 的线性特性。
-
对 Q 的梯度:
- 通过链式法则分解为
。
- A 对 Q 的导数为
,结合 Softmax 导数性质
,最终化简为:
其中
反映权重分布与均值的差异,调节梯度方向。
- 通过链式法则分解为
-
对 K 的梯度: 推导逻辑与 Q 对称,结果为:
差异源于 K 在点积中的转置位置,梯度包含
而非 V K。
在 LLM 中的梯度应用
在大语言模型训练中,梯度计算是反向传播优化的核心,直接影响参数更新效率:
- 参数更新机制: 梯度
指示投影矩阵的调整方向。例如,若某查询 - 键对的梯度较大,说明其关联权重需重点优化,模型会增强该语义关联的学习。
- 训练稳定性优化:
- 缩放因子
通过控制梯度幅度避免消失,尤其在高维特征(如
)时效果显著。
- 权重差异矩阵
具有 “抗过拟合” 作用:当注意力集中于少数位置时,梯度幅度自动减小,防止模型过度依赖局部信息。
- 缩放因子
代码示例与解析
import torch
# 定义可学习的投影矩阵
Q = torch.randn(8, 64, requires_grad=True) # 8个查询,每个维度64
K = torch.randn(8, 64, requires_grad=True) # 8个键,维度与Q一致
V = torch.randn(8, 128, requires_grad=True) # 8个值,维度128
d_k = Q.size(1) # d_k=64
# 前向传播:计算注意力输出
A = torch.matmul(Q, K.T) / torch.sqrt(torch.tensor(d_k)) # 缩放点积:(8,8)
Ω = torch.softmax(A, dim=1) # 注意力权重:(8,8)
O = torch.matmul(Ω, V) # 输出:(8,128)
# 反向传播:模拟损失函数梯度(此处用随机目标)
loss = (O - torch.randn_like(O)).mean()
loss.backward()
# 打印梯度范数(反映更新幅度)
print("Q梯度范数:", torch.norm(Q.grad))
print("K梯度范数:", torch.norm(K.grad))
print("V梯度范数:", torch.norm(V.grad))
代码解析:
- 可导投影层:
通常由模型参数生成(如线性层),此处用随机张量模拟可学习参数。
- 缩放点积核心逻辑:通过
转置实现查询 - 键的成对匹配,除以
确保数值稳定。
- 梯度观察:打印梯度范数可监控参数更新强度,若某矩阵范数异常增大,需调整学习率或启用正则化。
总结
标准点积注意力的梯度推导揭示了其优化本质:通过 Softmax 的 “竞争机制”()和缩放操作,平衡了注意力分布的集中性与梯度有效性。在 LLM 中,这一机制使模型能动态调整语义关联强度,同时通过反向传播高效优化参数。代码实现则借助 PyTorch 的自动微分,将复杂的矩阵导数计算抽象为简洁的 API 调用,体现了理论与工程的深度结合。
Q2:证明缩放因子
对梯度方差稳定性的作用(假设
)
缩放因子
是啥?
在标准点积注意力机制里,会用到公式 ,这里的
就是缩放因子。其中 Q 是查询矩阵,K 是键矩阵,
是 Q 和 K 的维度。这个缩放因子的出现,主要是为了解决高维情况下点积结果数值过大的问题。要是没有它,后续的计算就容易出状况,比如梯度不稳定,这会影响模型训练的效果和效率。
证明过程
1. 先看未缩放时
的元素分布
假设 Q 和 K 的元素都服从标准正态分布 。设 Q 是
的矩阵,K 是
的矩阵,那么
中元素
。 根据正态分布的性质,如果
相互独立且都服从
,那么
服从
。在这里,因为
和
都服从
且相互独立,所以
服从
。这意味着,随着
的增大,
元素的方差也会增大。
2. 再看缩放后的
的元素分布
当我们用 对
进行缩放时,得到
。根据正态分布的性质,若
,那么
(c 为常数)。所以
服从
。也就是说,不管
怎么变,缩放后元素的方差始终稳定在 1。
3. 接着分析 Softmax 函数的输入对梯度的影响
Softmax 函数用于将输入转换为概率分布,公式为 。当输入的方差较大时,Softmax 函数的输出会趋近于 one - hot 向量。这是因为方差大意味着输入值之间的差异大,经过指数运算后,大的值会变得更大,小的值会变得更小。 对 Softmax 函数求导,其导数与输入值和输出值都有关系。当 Softmax 输出趋近于 one - hot 向量时,梯度会变得非常小,出现梯度消失的问题。而当我们使用
缩放输入后,输入的方差稳定,Softmax 函数的输出分布更加合理,梯度也就更加稳定。
4. 最后看梯度方差的变化
在反向传播过程中,梯度的计算与 Softmax 函数的输出有关。未缩放时,由于 元素方差随
增大而增大,Softmax 输出不稳定,导致梯度方差也不稳定。而缩放后,
元素方差稳定,Softmax 输出稳定,梯度方差也随之稳定。
在 LLM 中的作用
在大语言模型(LLM)里,缩放因子 对梯度方差稳定性的作用至关重要。
- 提升训练稳定性:在训练过程中,如果梯度方差不稳定,可能会出现梯度消失或梯度爆炸的问题。梯度消失会使模型参数更新缓慢,甚至停滞不前;梯度爆炸则会使参数更新过大,导致模型无法收敛。使用
缩放后,梯度方差稳定,能有效避免这些问题,让模型训练更加稳定。
- 加速收敛:稳定的梯度能让模型在每次迭代中更合理地更新参数,避免了因梯度异常而进行的无效更新。这样,模型可以更快地收敛到最优解,提高训练效率。
代码示例
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# 定义维度范围
d_k_values = [16, 32, 64, 128, 256, 512, 1024]
num_trials = 100
# 存储未缩放和缩放后的梯度方差
unscaled_variances = []
scaled_variances = []
for d_k in d_k_values:
unscaled_grads = []
scaled_grads = []
for _ in range(num_trials):
# 生成 Q 和 K 矩阵
Q = torch.randn(1, d_k)
K = torch.randn(1, d_k)
# 未缩放的点积
unscaled_dot_product = torch.matmul(Q, K.T)
unscaled_softmax = F.softmax(unscaled_dot_product, dim=1)
# 模拟损失函数,这里简单用求和
unscaled_loss = unscaled_softmax.sum()
unscaled_loss.backward(retain_graph=True)
unscaled_grads.append(Q.grad.view(-1).clone())
Q.grad.zero_()
# 缩放后的点积
scaled_dot_product = torch.matmul(Q, K.T) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
scaled_softmax = F.softmax(scaled_dot_product, dim=1)
scaled_loss = scaled_softmax.sum()
scaled_loss.backward()
scaled_grads.append(Q.grad.view(-1).clone())
Q.grad.zero_()
# 计算梯度方差
unscaled_variances.append(torch.var(torch.stack(unscaled_grads)))
scaled_variances.append(torch.var(torch.stack(scaled_grads)))
# 绘制方差变化图
plt.plot(d_k_values, unscaled_variances, label='Unscaled')
plt.plot(d_k_values, scaled_variances, label='Scaled')
plt.xlabel('$d_k$')
plt.ylabel('Gradient Variance')
plt.title('Effect of Scaling Factor on Gradient Variance')
plt.legend()
plt.show()
代码解释
- 数据生成:代码中通过
torch.randn函数生成服从标准正态分布的 Q 和 K 矩阵。 - 未缩放和缩放计算:分别计算未缩放的
和缩放后的
,然后对它们应用 Softmax 函数,模拟损失函数并进行反向传播计算梯度。
- 方差计算与绘图:多次试验后,计算未缩放和缩放情况下梯度的方差,并使用
matplotlib绘制方差随变化的曲线。从图中可以直观地看到,未缩放时梯度方差随
增大而增大,而缩放后梯度方差保持相对稳定。
总结
缩放因子 通过将
元素的方差稳定在 1,避免了 Softmax 函数输出的极端情况,从而保证了梯度方差的稳定性。在大语言模型训练中,这一稳定性对于避免梯度消失和爆炸、提高训练效率和模型性能具有重要意义。代码示例则直观地展示了缩放因子对梯度方差稳定性的作用。
Q3:多头注意力的并行计算复杂度分析(FLOPs 公式推导)
多头注意力是啥?
多头注意力机制是 Transformer 架构中的关键组件,它允许模型在不同的表示子空间中并行地关注输入序列的不同部分。标准的注意力机制可以表示为 ,而多头注意力则是将 Q、K、V 分别线性投影到 h 个低维子空间(头),并行计算每个头的注意力,最后将结果拼接并再次进行线性投影。
多头注意力的主要作用是增强模型对不同特征和模式的捕捉能力。通过多个头并行工作,模型可以从不同的角度关注输入序列,从而提取更丰富的语义信息。
并行计算复杂度分析及 FLOPs 公式推导
1. 单个头的计算复杂度
首先,我们来分析单个头的注意力计算复杂度。假设输入序列的长度为 n,查询、键和值矩阵的维度分别为 、
和
,在多头注意力中,每个头的维度通常为
(这里 h 是头的数量)。
-
计算
: 计算
是一个矩阵乘法操作,Q 是
的矩阵,
是
的矩阵。矩阵乘法的计算复杂度为
。具体来说,对于
中的每个元素,需要进行
次乘法和
次加法,总共
个元素,所以乘法和加法的操作次数大约为
次。
-
缩放操作: 对
进行缩放,即除以
,这只需要
次除法操作,相对于矩阵乘法来说复杂度较低,在大 n 和
的情况下可以忽略不计。
-
Softmax 操作: Softmax 操作是对矩阵的每一行进行归一化。对于每一行(共 n 行),需要进行 n 次指数运算、
次加法和 n 次除法。指数运算的复杂度相对较高,但在实际计算中,通常可以将其视为常数时间操作。所以 Softmax 操作的复杂度为
,同样在大 n 和
的情况下相对于矩阵乘法可忽略。
-
计算
: 这又是一个矩阵乘法操作,
是
的矩阵,V 是
的矩阵,计算复杂度为
。
综合来看,单个头的注意力计算复杂度主要由矩阵乘法决定,为
。
2. h 个头的并行计算复杂度
由于 h 个头是并行计算的,所以 h 个头的计算复杂度仍然是 ,因为并行计算可以同时处理多个头,而不是依次处理。
3. 线性投影的计算复杂度
在多头注意力中,还需要

最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



