从不确定性减少视角理解KL散度
【 Transformer 系列,故事从 d k \sqrt{d_k} dk说起】
LLM这么火,Transformer厥功甚伟,某天心血来潮~,再去看看!
它长这个样子: 深入浅出 Transformer
看完后,想起了老生常谈 d k \sqrt{d_k} dk问题,必须一探究竟:Transformer 中缩放点积注意力机制探讨:除以根号 dk 理由及其影响
感觉不够清楚,还是再Review下考研概率论,有了:基于考研概率论知识解读 Transformer:为何自注意力机制要除以根号 dk,中间会涉及初始化、标准化、Sofrmax函数,于是继续
【初始化相关】:深度学习中的常见初始化方法:原理、应用与比较
【标准化系列】: 数据为什么要进行标准化:Z-标准化的神奇蜕变,带出了关联知识点: 深度 “炼丹” 术之 Batch Normalization 与 Z - 标准化:开启数据的神秘转换
【Softmax复习】:Softmax 层反向传播梯度计算实例解析,中间想到了经常配套使用的交叉熵,于是梳理了交叉熵的前世今生KL 散度:多维度解读概率分布间的隐秘 “距离”
熵与交叉熵:从不确定性角度理解 KL 散度
机器学习、深度学习关于熵你所需要知道的一切
本文核心
- 由于熵表征不确定性大小,且基于真实分布 P P P 本身编码是最“有效”的方式(即不确定性最小),所以当使用其他分布 Q Q Q 来近似 P P P 进行编码时,必然会引入更多的不确定性,也就意味着交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 肯定会比熵 H ( P ) H(P) H(P) 大。
- D K L ( P ∣ ∣ Q ) = H ( P , Q ) − H ( P ) D_{KL}(P||Q)=H(P, Q)-H(P) DKL(P∣∣Q)=H(P,Q)−H(P),即:KL散度 = 交叉熵[H(P, Q)]-熵[H(Q)],鉴于交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 大于等于熵 H ( P ) H(P) H(P),KL散度 D K L ( P ∣ ∣ Q ) D_{KL}(P||Q) DKL(P∣∣Q) 必然是非负的。
- 由于真实分布 其熵 H ( P ) H(P) H(P) 是一个固定值,所以最小化 KL 散度等价于最小化交叉熵 H ( P , Q ) H(P, Q) H(P,Q),即
min Q D K L ( P ∣ ∣ Q ) = min Q ( H ( P , Q ) − H ( P ) ) = min Q H ( P , Q ) \min_{Q} D_{KL}(P||Q)=\min_{Q}(H(P, Q)-H(P))=\min_{Q} H(P, Q) QminDKL(P∣∣Q)=Qmin(H(P,Q)−H(P))=QminH(P,Q)这就解释了为何机器/深度学习领域将交叉熵作为损失函数。
引言
在信息论与概率统计的交融领域,KL散度(Kullback - Leibler Divergence)扮演着举足轻重的角色。从不确定性减少的独特视角深入探究KL散度,不仅能揭示其本质内涵,还能为诸多实际应用提供清晰的理论支撑。而在这个过程中,理解熵、交叉熵以及它们之间的关系至关重要,尤其要明确熵是表达不确定性大小的量,并且交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 必然大于等于熵 H ( P ) H(P) H(P)。
一、信息熵:不确定性的精准度量
(一)信息熵的定义与本质
信息熵 H ( P ) H(P) H(P) 作为衡量概率分布 P P P 不确定性程度的核心指标,其数学表达式为: H ( P ) = − ∑ i P ( x i ) log P ( x i ) H(P)=-\sum_{i}P(x_i)\log P(x_i) H(P)=−i∑P(xi)logP(xi)其中, x i x_i xi 涵盖了随机变量所有可能的取值, P ( x i ) P(x_i) P(xi) 则代表取值 x i x_i xi 出现的概率。信息熵本质上量化了从分布 P P P 中随机抽取一个事件时,平均所蕴含的信息量。由于熵是表达不确定性大小的,分布的不确定性程度越高,对应的信息熵数值越大。
例如,考虑一个均匀的六面骰子,每个面向上的概率均为 1 6 \frac{1}{6} 61。在此情形下,每次掷骰子的结果都具有较高的不确定性,其信息熵也相对较大。这是因为在掷骰子之前,我们难以准确预测哪个面会朝上,每个结果都带来了较多的“意外”信息。反之,若骰子经过特殊处理,使得某一面出现的概率为1,而其他面为0,此时该骰子的结果几乎没有不确定性,信息熵也就趋近于0。
(二)信息熵的直观理解
可以将信息熵看作是对随机事件结果“混乱程度”或“不可预测性”的一种度量。以抛硬币为例,一枚标准的硬币,正面和反面出现的概率各为 0.5 0.5 0.5,其信息熵相对较高,因为在抛硬币之前,我们无法确切知晓结果是正面还是反面。而如果硬币被做了手脚,总是正面朝上,那么它的信息熵就为0,因为结果完全确定,没有任何不确定性。
二、交叉熵:近似分布下的不确定性测度
(一)交叉熵的定义与作用
当我们尝试使用另一个概率分布 Q Q Q 来近似真实的概率分布 P P P 时,交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 便成为了衡量这种近似效果的重要工具。其计算公式为: H ( P , Q ) = − ∑ i P ( x i ) log Q ( x i ) H(P, Q)=-\sum_{i}P(x_i)\log Q(x_i) H(P,Q)=−i∑P(xi)logQ(xi)交叉熵的意义在于,它反映了在采用分布 Q Q Q 对源于分布 P P P 的事件进行编码、预测或描述时,平均所需要的信息量。
由于熵表征不确定性大小,且基于真实分布 P P P 本身编码是最“有效”的方式(即不确定性最小),所以当使用其他分布 Q Q Q 来近似 P P P 进行编码时,必然会引入更多的不确定性,也就意味着交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 肯定会比熵 H ( P ) H(P) H(P) 大。
例如,假设我们正在预测明天的天气状况,真实的天气概率分布 P P P 为晴天 60 % 60\% 60%、多云 30 % 30\% 30%、下雨 10 % 10\% 10%。然而,由于某些原因,我们错误地认为概率分布 Q Q Q 是晴天 30 % 30\% 30%、多云 30 % 30\% 30%、下雨 40 % 40\% 40%。在这种情况下,基于错误的分布 Q Q Q 来预测明天天气,相较于基于真实分布 P P P 预测,必然会带来更多的不确定性,即交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 大于熵 H ( P ) H(P) H(P)。这表明使用不恰当的分布 Q Q Q 进行预测时,我们对预测结果更加不确定,需要更多的信息量来描述这种预测情况。
(二)交叉熵与信息熵的关系
交叉熵与信息熵密切相关,信息熵 H ( P ) H(P) H(P) 可以视为交叉熵 H ( P , P ) H(P, P) H(P,P) 的特殊情况,即当我们使用真实分布 P P P 自身来对事件进行编码或预测时的平均信息量。而交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 则衡量了使用近似分布 Q Q Q 替代真实分布 P P P 时,信息量的变化情况。由于熵代表了基于真实分布的最小不确定性,所以交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 总是大于等于 H ( P ) H(P) H(P)。
(三)举个🌰
二分类例子
假设我们要判断一封邮件是垃圾邮件还是正常邮件。真实情况中,邮件是垃圾邮件的概率为 P ( 垃圾邮件 ) = 0.8 P(\text{垃圾邮件}) = 0.8 P(垃圾邮件)=0.8,是正常邮件的概率为 P ( 正常邮件 ) = 0.2 P(\text{正常邮件}) = 0.2 P(正常邮件)=0.2,即真实分布 P P P 为: P = [ 0.8 , 0.2 ] P = [0.8, 0.2] P=[0.8,0.2]。
某邮件分类模型对邮件类别的预测概率分布 Q Q Q 为:认为是垃圾邮件的概率 Q ( 垃圾邮件 ) = 0.6 Q(\text{垃圾邮件}) = 0.6 Q(垃圾邮件)=0.6,是正常邮件的概率 Q ( 正常邮件 ) = 0.4 Q(\text{正常邮件}) = 0.4 Q(正常邮件)=0.4,即 Q = [ 0.6 , 0.4 ] Q = [0.6, 0.4] Q=[0.6,0.4]。
根据交叉熵公式 H ( P , Q ) = − ∑ i P ( x i ) log Q ( x i ) H(P, Q)=-\sum_{i}P(x_i)\log Q(x_i) H(P,Q)=−i∑P(xi)logQ(xi),计算过程如下:
H ( P , Q ) = − ( P ( 垃圾邮件 ) log Q ( 垃圾邮件 ) + P ( 正常邮件 ) log Q ( 正常邮件 ) ) = − ( 0.8 × log ( 0.6 ) + 0.2 × log ( 0.4 ) ) \begin{align*} H(P, Q)&= - (P(\text{垃圾邮件})\log Q(\text{垃圾邮件}) + P(\text{正常邮件})\log Q(\text{正常邮件}))\\ &= - (0.8\times\log(0.6) + 0.2\times\log(0.4)) \end{align*} H(P,Q)=−(P(垃圾邮件)logQ(垃圾邮件)+P(正常邮件)logQ(正常邮件))=−(0.8×log(0.6)+0.2×log(0.4))