🧠 首先搞清楚 LoRA 是怎么做微调的
我们原来要训练的参数矩阵是 WWW,但 LoRA 说:
别动 W,我在它旁边加一个低秩矩阵 ΔW=UV\Delta W = UVΔW=UV,只训练这个部分!
也就是说,LoRA 用一个新的权重矩阵:
W′=W+UV W' = W + UV W′=W+UV
只训练 UUU 和 VVV,WWW 不动。
📦 所以前向传播其实用的是:
模型输入x⟶W′x=Wx+UVx⟶输出⟶L \text{模型输入}x \longrightarrow W'x = Wx + UVx \longrightarrow \text{输出} \longrightarrow \mathcal{L} 模型输入x⟶W′x=Wx+UVx⟶输出⟶L
在这个过程中,损失函数 L\mathcal{L}L 是基于 W+UVW + UVW+UV 来计算的。
🔁 反向传播的时候怎么求梯度?
LoRA 要训练的是 UUU 和 VVV,所以我们要算:
∂L∂U和∂L∂V \frac{\partial \mathcal{L}}{\partial U} \quad \text{和} \quad \frac{\partial \mathcal{L}}{\partial V} ∂U∂L和∂V∂L
但问题是:损失函数 L\mathcal{L}L 不是直接依赖 UUU 和 VVV,而是依赖 UVUVUV
所以要用链式法则,先对 UVUVUV 求导,然后传播回 UUU、VVV。而对UV求导等价于对WWW求导
✅ 关键点来了
我们记:
∂L∂W=G \frac{\partial \mathcal{L}}{\partial W} = G ∂W∂L=G
这个 GGG 就是“如果我们在做全量微调,该怎么更新 WWW 的梯度”。
LoRA 说:
“虽然我不更新 WWW,但我要更新的是 UVUVUV。所以我也可以用这个 GGG 来指导我怎么更新 UUU 和 VVV。”
于是我们得到:
∂L∂U=GV⊤,∂L∂V=U⊤G \frac{\partial \mathcal{L}}{\partial U} = G V^\top, \quad \frac{\partial \mathcal{L}}{\partial V} = U^\top G ∂U∂L=GV⊤,∂V∂L=U⊤G
LoRA 的梯度建立在 ∂L∂W\frac{\partial \mathcal{L}}{\partial W}∂W∂L 上, 是因为它相当于“用低秩矩阵 UVUVUV 来代替全量的参数更新”, 所以梯度传播也必须从 ∂L∂W\frac{\partial \mathcal{L}}{\partial W}∂W∂L 开始。
LoRA 往往只是显存不足的无奈之选,因为一般情况下全量微调的效果都会优于 LoRA,所以如果算力足够并且要追求效果最佳时,请优先选择全量微调。
使用 LoRA 的另一个场景是有大量的微型定制化需求,要存下非常多的微调结果,此时使用 LoRA 能减少储存成本。
🔍 为什么
为什么 ∂L∂W\frac{\partial \mathcal{L}}{\partial W}∂W∂L,就是对 UVUVUV 的梯度?
换句话说:LoRA 中的 W′=W+UVW' = W + UVW′=W+UV,那我们训练时不是更新 WWW,只更新 UVUVUV,那为什么还能用 ∂L∂W\frac{\partial \mathcal{L}}{\partial W}∂W∂L 来指导 UUU 和 VVV 的更新呢?
✅ 答案是:因为前向传播中 W+UVW + UVW+UV 是一起作为整体参与运算的
所以:
∂L∂W=∂L∂(W+UV)=∂L∂(UV) \frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial (W + UV)} = \frac{\partial \mathcal{L}}{\partial (UV)} ∂W∂L=∂(W+UV)∂L=∂(UV)∂L
这是因为:
- 我们的模型使用的是 W+UVW + UVW+UV
- 所以损失函数 L\mathcal{L}L 是以 W+UVW + UVW+UV 为输入计算出来的
- 那么对 WWW 求导,其实是对这个整体求导
- 而因为 WWW 是固定的(不训练,看作常数),所以梯度全部由 UVUVUV 来承接
- 本来我们应该更新 WWW:
W←W−η∂L∂W W \leftarrow W - \eta \frac{\partial \mathcal{L}}{\partial W} W←W−η∂W∂L - 现在我们不动 WWW,让 UVUVUV 来“做这个事情”:
W+UV←W+UV−η⋅(LoRA方向上的梯度) W + UV \leftarrow W + UV - \eta \cdot \left(\text{LoRA方向上的梯度}\right) W+UV←W+UV−η⋅(LoRA方向上的梯度)
所以如果要算 UVUVUV 的导数,就是算 ∂L∂W\frac{\partial \mathcal{L}}{\partial W}∂W∂L