交叉熵的介绍
1、信息量
概率越小,信息量越大,事件X=x0X=x_0X=x0的信息量为:
I(x0)=−log(p(x0))I(x_0)=-log(p(x_0))I(x0)=−log(p(x0))
2、熵
熵表示所有信息量的期望:
H(x)=−∑i=1np(xi)log(p(xi))H(x)=-\sum_{i=1}^n p(x_i)log(p(x_i))H(x)=−i=1∑np(xi)log(p(xi))
其中n代表事件X有n种可能
3、相对熵(KL散度)
DKL(p∣∣q)=∑i=1np(xi)log(p(xi)q(xi))D_{KL}(p||q)=\sum_{i=1}^np(x_i)log(\frac{p(x_i)}{q(x_i)})DKL(p∣∣q)=i=1∑np(xi)log(q(xi)p(xi))
物理意义:如果用P来描述目标问题,而不是用Q来描述问题,得到的信息增量
在机器学习中,P往往表示样本的真实分布,q表示模型预测的分布,相对熵越小,表示q分布和p分布越接近
4、交叉熵 (越小越好)
相对熵可以变形为:
DKL(p∣∣q)=−H(p(x))+[−∑i=1np(xi)log(q(xi))]D_{KL}(p||q)=-H(p(x))+[-\sum_{i=1}^np(x_i)log(q(x_i))]DKL(p∣∣q)=−H(p(x))+[−i=1∑np(xi)log(q(xi))]
等式的前半部分是p的熵,后半部分就是交叉熵:
H(p,q)=−∑i=1np(xi)log(q(xi))H(p,q)=-\sum_{i=1}^np(x_i)log(q(x_i))H(p,q)=−i=1∑np(xi)log(q(xi))
在机器学习中,我们需要评估labl和predicts之间的差距,可以使用KL散度,但由于KL散度前半部分不变,故在优化过程中,只需要关注交叉熵就行,所以一般在机器学习中直接用交叉熵作为loss函数,评估模型。
分析
在回归问题中,常用MSE做loss函数,但在逻辑分类中却用不好,这时需要用交叉熵
1. 为什么MSE不适用于分类问题?
当sigmod函数和MSE一起使用时会出现梯度消失,原因是在使用MSE时,w、b的梯度均与sigmoid函数对z的偏导有关系,而sigmoid函数的偏导在自变量非常大或非常小是,偏导数的值接近与0,这将导致w、b的梯度将不会变化,也就是出现所谓的梯度消失现象。而使用交叉熵时,梯度就不会出现上述情况。所以MSE不适用于分类问题。
2. 交叉熵不适用于回归问题
当MSE和交叉熵同事应用于多分类场景下,MSE对每个输出结果都非常看重,而交叉熵只对正确分类的结果看重。可见MSE除了让正确的分类尽量变大,还会让错误的分类变得平均,这对回归问题显得很重要,所以MSE适合回归问题的loss函数。
机器学习中交叉熵的应用
1、交叉熵在单分类中的使用
- 这里的单类别指:每个样本只能有一个类别
- 交叉熵在单分类问题上的loss函数:
loss=−∑j=1m∑i=1nyjilog(y^ji)loss=-\sum_{j=1}^m\sum_{i=1}^ny_{ji}log(\hat{y}_{ji})loss=−j=1∑mi=1∑nyjilog(y^ji) - 这里的预测概率是通过softmax计算,概率合为1
2、交叉熵在多分类中使用
- 这里的多类别指:每个样本可以有多个类别
- 交叉熵在多分类问题上的loss问题:
loss=∑j=1m∑i=1n−yjilog(y^ji)−(1−yji)log(1−y^ji)loss=\sum_{j=1}^m\sum_{i=1}^n-y_{ji}log(\hat{y}_{ji})-(1-y_{ji})log(1-\hat{y}_{ji})loss=j=1∑mi=1∑n−yjilog(y^ji)−(1−yji)log(1−y^ji) - 这里的预测是通过sigmoid计算,每个label都是独立分布的,输出归一化
交叉熵在pytorch的使用
nn.CrossEntropyLoss()
这个损失函数和通常的交叉熵函数公式不一样,它是nn.logSoftmax()和nn.NLLLoss()的整合,公式如下:
loss(x, class )=−log(exp(x[ class ])∑jexp(x[j]))=−x[ class ]+log(∑jexp(x[j]))
\operatorname{loss}(x, \text { class })=-\log \left(\frac{\exp (x[\text { class }])}{\sum_{j} \exp (x[j])}\right)=-x[\text { class }]+\log \left(\sum_{j} \exp (x[j])\right)
loss(x, class )=−log(∑jexp(x[j])exp(x[ class ]))=−x[ class ]+log(j∑exp(x[j]))