知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 KL散度公式变化

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 KL散度公式变化

flyfish

对于两个概率分布 PPP(真实分布)和 QQQ(模型预测分布),KL散度的定义是:
DKL(P∥Q)=∑xP(x)log⁡(P(x)Q(x)) D_{KL}(P \| Q) = \sum_{x} P(x) \log\left( \frac{P(x)}{Q(x)} \right) DKL(PQ)=xP(x)log(Q(x)P(x))

怎么就变成如下公式呢

DKL(P∥Q)=∑xP(x)log⁡P(x)−∑xP(x)log⁡Q(x) D_{KL}(P \| Q) = \sum_{x} P(x) \log P(x) - \sum_{x} P(x) \log Q(x) DKL(PQ)=xP(x)logP(x)xP(x)logQ(x)

商的对数等于对数的差

“商的对数等于对数的差”就是:两个数相除的对数,等于这两个数分别取对数后再相减。

log⁡c(ab)=log⁡ca−log⁡cb\log_c \left( \frac{a}{b} \right) = \log_c a - \log_c blogc(ba)=logcalogcb

先搞懂“对数”到底是啥

对数其实是“指数的反过来”。比如“以c为底a的对数”(写成log⁡ca\log_c alogca),本质是在问:“c的多少次方等于a?”
假设这个“多少次方”是x,那就能写成:
x=log⁡cax = \log_c ax=logca
反过来,也可以说:cx=ac^x = acx=a(c的x次方等于a)。

给两个数的对数起个名字

现在有两个正数a和b(因为对数里的数必须是正数)。
按上面的逻辑:

  • 设“c的多少次方等于a”是x,也就是x=log⁡cax = \log_c ax=logca,所以a=cxa = c^xa=cx
  • 设“c的多少次方等于b”是y,也就是y=log⁡cby = \log_c by=logcb,所以b=cyb = c^yb=cy

看看a除以b等于啥

我们想算a÷ba \div ba÷b(也就是ab\frac{a}{b}ba),代入上面的a=cxa = c^xa=cxb=cyb = c^yb=cy,得到:
ab=cxcy\frac{a}{b} = \frac{c^x}{c^y}ba=cycx

这里用到一个指数的常识:同底数的幂相除,底数不变,指数相减。比如25÷23=25−3=222^5 \div 2^3 = 2^{5-3} = 2^225÷23=253=22
所以cxcy=cx−y\frac{c^x}{c^y} = c^{x - y}cycx=cxy,也就是:
ab=cx−y\frac{a}{b} = c^{x - y}ba=cxy

把上面的结果转成对数

现在看ab=cx−y\frac{a}{b} = c^{x - y}ba=cxy,这句话用对数的话来说就是:“c的(x - y)次方等于ab\frac{a}{b}ba”。
按对数的定义,“c的多少次方等于ab\frac{a}{b}ba”,这个“多少次方”就是log⁡c(ab)\log_c \left( \frac{a}{b} \right)logc(ba)
所以:
log⁡c(ab)=x−y\log_c \left( \frac{a}{b} \right) = x - ylogc(ba)=xy

换回原来的对数

前面已经知道x=log⁡cax = \log_c ax=logcay=log⁡cby = \log_c by=logcb,代入上面的式子,得到:
log⁡c(ab)=log⁡ca−log⁡cb\log_c \left( \frac{a}{b} \right) = \log_c a - \log_c blogc(ba)=logcalogcb

这样两个数相除的对数,等于这两个数的对数相减。

求和可以和加法、乘法(常数倍)交换顺序

求和的线性性质是求和可以和加法、乘法(常数倍)“交换顺序”。

规则

f(x)f(x)f(x)g(x)g(x)g(x) 是任意两个关于 xxx 的函数,ccc 是任意常数(与 xxx 无关),则:

  1. 对加法的分配性
    两个函数相加后的求和,等于两个函数分别求和后再相加。
    公式:
    ∑x[f(x)+g(x)]=∑xf(x)+∑xg(x) \sum_{x} \left[ f(x) + g(x) \right] = \sum_{x} f(x) + \sum_{x} g(x) x[f(x)+g(x)]=xf(x)+xg(x)

  2. 对常数倍的分配性
    常数乘以函数后的求和,等于常数乘以该函数的求和。
    公式:
    ∑x[c⋅f(x)]=c⋅∑xf(x) \sum_{x} \left[ c \cdot f(x) \right] = c \cdot \sum_{x} f(x) x[cf(x)]=cxf(x)

“乘法分配律”(如 c⋅(a+b)=c⋅a+c⋅bc \cdot (a + b) = c \cdot a + c \cdot bc(a+b)=ca+cb),求和的线性性质本质是“求和符号对加法和常数倍的分配”。

比如:

  • xxx 只能取1和2,f(1)=2f(1)=2f(1)=2f(2)=3f(2)=3f(2)=3g(1)=1g(1)=1g(1)=1g(2)=4g(2)=4g(2)=4,则:
    左边 ∑[f(x)+g(x)]=(2+1)+(3+4)=3+7=10\sum [f(x)+g(x)] = (2+1)+(3+4) = 3+7=10[f(x)+g(x)]=(2+1)+(3+4)=3+7=10
    右边 ∑f(x)+∑g(x)=(2+3)+(1+4)=5+5=10\sum f(x) + \sum g(x) = (2+3)+(1+4)=5+5=10f(x)+g(x)=(2+3)+(1+4)=5+5=10,两边相等,验证了第一条规则。

  • c=2c=2c=2,则:
    左边 ∑[2⋅f(x)]=2⋅2+2⋅3=4+6=10\sum [2 \cdot f(x)] = 2 \cdot 2 + 2 \cdot 3 = 4 + 6 = 10[2f(x)]=22+23=4+6=10
    右边 2⋅∑f(x)=2⋅(2+3)=2⋅5=102 \cdot \sum f(x) = 2 \cdot (2+3) = 2 \cdot 5 = 102f(x)=2(2+3)=25=10,两边相等,验证了第二条规则。

具体到KL散度

对数性质,将KL散度写成:
DKL(P∥Q)=∑xP(x)⋅[log⁡P(x)−log⁡Q(x)] D_{KL}(P \| Q) = \sum_{x} P(x) \cdot \left[ \log P(x) - \log Q(x) \right] DKL(PQ)=xP(x)[logP(x)logQ(x)]

现在要拆分成两个求和项,步骤如下:

步骤1:先展开括号里的乘法

根据乘法分配律(即a⋅(b−c)=a⋅b−a⋅ca \cdot (b - c) = a \cdot b - a \cdot ca(bc)=abac),括号里的减法可以和外面的P(x)P(x)P(x)分别相乘:
P(x)⋅[log⁡P(x)−log⁡Q(x)]=P(x)⋅log⁡P(x)−P(x)⋅log⁡Q(x) P(x) \cdot \left[ \log P(x) - \log Q(x) \right] = P(x) \cdot \log P(x) - P(x) \cdot \log Q(x) P(x)[logP(x)logQ(x)]=P(x)logP(x)P(x)logQ(x)

步骤2:应用求和的线性性质

将上式代入KL散度的表达式,得到:
DKL(P∥Q)=∑x[P(x)log⁡P(x)−P(x)log⁡Q(x)] D_{KL}(P \| Q) = \sum_{x} \left[ P(x) \log P(x) - P(x) \log Q(x) \right] DKL(PQ)=x[P(x)logP(x)P(x)logQ(x)]

此时,求和符号后面是两个项的差(即f(x)−g(x)f(x) - g(x)f(x)g(x)
其中
f(x)=P(x)log⁡P(x)f(x) = P(x)\log P(x)f(x)=P(x)logP(x)
g(x)=P(x)log⁡Q(x)g(x) = P(x)\log Q(x)g(x)=P(x)logQ(x))。

根据求和的线性性质∑(f−g)=∑f−∑g\sum (f - g) = \sum f - \sum g(fg)=fg,可以直接拆分:

∑x[P(x)log⁡P(x)−P(x)log⁡Q(x)]=∑xP(x)log⁡P(x)−∑xP(x)log⁡Q(x) \sum_{x} \left[ P(x) \log P(x) - P(x) \log Q(x) \right] = \sum_{x} P(x) \log P(x) - \sum_{x} P(x) \log Q(x) x[P(x)logP(x)P(x)logQ(x)]=xP(x)logP(x)xP(x)logQ(x)

### 知识蒸馏中的KL公式知识蒸馏过程中,为了使学生模型更好地逼近教师模型的行为,常用KL(Kullback-Leibler Divergence)衡量两个概率分布之间的差异。具体来说,在知识蒸馏场景下,KL被用来比较教师模型产生的软标签与学生模型预测的概率分布。 对于给定的一组输入样本 \(x\) ,设教师模型输出经过softmax层后的类别概率分布为 \(T(y|x)\),而学生模型对应的输出分布记作 \(S(y|x)\) 。引入温参数 \(\tau\) 来调整这两个分布的锐利程,则有: \[ T_{\text{soft}}(y|x)=\frac{\exp(z^{t}_i/\tau)}{\sum_j\exp(z^{t}_j/\tau)}, S_{\text{soft}}(y|x)=\frac{\exp(z^{s}_i/\tau)}{\sum_j\exp(z^{s}_j/\tau)} \] 其中 \(z_i^t, z_i^s\) 分别表示来自教师和学生的未归一化logits值[^1]。 此时,针对单个样例\(x\) 的KL定义如下: \[D_{KL}(T_{\text{soft}}, S_{\text{soft}}|x) = \sum_y T_{\text{soft}}(y|x)\cdot[\ln(T_{\text{soft}}(y|x))-\ln(S_{\text{soft}}(y|x))] \] 最终损失函数可由硬标签交叉熵项加上加权后的KL组成,以平衡原始任务的学习与模仿教师的任务[^5]。 ```python import torch.nn.functional as F def knowledge_distillation_loss(student_logits, teacher_logits, temperature=2.0, alpha=0.5): student_soft = F.log_softmax(student_logits / temperature, dim=-1) teacher_soft = F.softmax(teacher_logits / temperature, dim=-1) kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2) hard_labels = F.cross_entropy(student_logits, labels) total_loss = alpha * kd_loss + (1 - alpha) * hard_labels return total_loss ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二分掌柜的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值