为什么分类模型使用交叉熵损失函数而不是均方误差?

在深度学习和机器学习领域,为模型选择正确的损失函数是至关重要的。对于回归问题,我们通常使用均方误差(Mean Squared Error, MSE)作为损失函数。然而,在处理分类问题时,交叉熵(Cross-Entropy, CE)损失函数却是更常见的选择。本篇文章将深入探讨为什么在分类任务中我们更倾向于使用交叉熵而不是均方误差。

1. 两种损失函数的简要回顾

在深入探讨原因之前,我们先来简单回顾一下这两种损失函数。

均方误差 (MSE)

MSE 计算的是预测值与真实值之间差值的平方的平均值。它通常用于回归问题,其中目标是预测一个连续值。
对于单个样本,MSE 的公式如下:
MSE=(ytrue−ypred)2MSE = (y_{true} - y_{pred})^2MSE=(ytrueypred)2

对于一个包含 N 个样本的数据集,MSE 的公式是:
MSE=1N∑i=1N(ytrue,i−ypred,i)2MSE = \frac{1}{N} \sum_{i=1}^{N} (y_{true, i} - y_{pred, i})^2MSE=N1i=1N(ytrue,iypred,i)2

其中 ytrue,iy_{true, i}ytrue,i 是第 i 个样本的真实值, ypred,iy_{pred, i}ypred,i 是模型的预测值。

交叉熵 (Cross-Entropy)

交叉熵源于信息论,它用于衡量两个概率分布之间的差异。在分类问题中,它衡量的是模型预测的概率分布与真实的概率分布之间的差异。

对于二分类问题,其损失函数(也称为二元交叉熵,Binary Cross-Entropy)通常与 SigmoidSigmoidSigmoid 激活函数结合使用,公式如下:
BCE=−(ytruelog⁡(ypred)+(1−ytrue)log⁡(1−ypred))BCE = - (y_{true} \log(y_{pred}) + (1 - y_{true}) \log(1 - y_{pred}))BCE=(ytruelog(ypred)+(1ytrue)log(1ypred))

对于多分类问题,其损失函数(也称为分类交叉熵,Categorical Cross-Entropy)通常与 SoftmaxSoftmaxSoftmax 激活函数结合使用,公式如下:
CCE=−∑j=1Cytrue,jlog⁡(ypred,j)CCE = - \sum_{j=1}^{C} y_{true, j} \log(y_{pred, j})CCE=j=1Cytrue,jlog(ypred,j)
其中 C 是类别的数量。

2. 为什么 MSE 不适合分类任务?

现在,让我们来探讨为什么 MSE 在分类任务中表现不佳。

输出的性质不匹配

分类模型的最终输出通常需要表示为概率。例如,在一个三分类问题中,对于一个输入样本,模型可能输出 [0.1,0.7,0.2][0.1, 0.7, 0.2][0.1,0.7,0.2],表示该样本属于第二类的概率是70%。这些概率值是通过 SigmoidSigmoidSigmoid(对于二分类)或 SoftmaxSoftmaxSoftmax(对于多分类)激活函数得到的,它们的输出范围在 [0, 1] 之间,并且对于 SoftmaxSoftmaxSoftmax 来说,所有类别的概率之和为1。

而 MSE 损失函数假设模型的输出是连续的,并且没有范围限制。直接将 MSE 应用于概率输出,其背后的假设(如误差服从高斯分布)与分类问题的本质(类别标签是离散的)是不符的。

梯度消失问题

这是使用 MSE 进行分类时最关键的问题之一。让我们以一个二分类问题为例,假设我们使用 SigmoidSigmoidSigmoid 作为激活函数,其输出 ypred=σ(z)y_{pred} = \sigma(z)ypred=σ(z),其中 zzz 是模型的线性输出。

如果我们使用 MSE,损失函数为 L=(ytrue−σ(z))2L = (y_{true} - \sigma(z))^2L=(ytrueσ(z))2
我们来计算损失函数对参数(例如 zzz)的梯度:
∂L∂z=∂L∂σ(z)∂σ(z)∂z=2(σ(z)−ytrue)⋅σ′(z)\frac{\partial L}{\partial z} = \frac{\partial L}{\partial \sigma(z)} \frac{\partial \sigma(z)}{\partial z} = 2(\sigma(z) - y_{true}) \cdot \sigma'(z)zL=σ(z)Lzσ(z)=2(σ(z)ytrue)σ(z)

SigmoidSigmoidSigmoid 函数的导数 σ′(z)=σ(z)(1−σ(z))\sigma'(z) = \sigma(z)(1-\sigma(z))σ(z)=σ(z)(1σ(z))。当 zzz 的绝对值很大时(即模型非常有信心地做出正确或错误的预测时),σ(z)\sigma(z)σ(z) 的值会非常接近 0 或 1。在这两种情况下,σ′(z)\sigma'(z)σ(z) 的值都非常接近 0。

  • 当预测非常错误时:例如,真实标签 ytrue=1y_{true}=1ytrue=1,但模型预测 ypred=σ(z)≈0y_{pred}=\sigma(z) \approx 0ypred=σ(z)0 (此时 zzz 是一个很大的负数)。我们希望此时有一个很大的梯度来快速纠正模型。但是,由于 σ′(z)≈0\sigma'(z) \approx 0σ(z)0,导致整个梯度 ∂L∂z\frac{\partial L}{\partial z}zL 也接近于 0。
  • 当预测非常正确时:例如,真实标签 ytrue=1y_{true}=1ytrue=1,模型预测 ypred=σ(z)≈1y_{pred}=\sigma(z) \approx 1ypred=σ(z)1 (此时 zzz 是一个很大的正数)。此时梯度也接近于0,这是我们期望的,因为模型已经表现很好了。

问题就在于,当模型犯下严重错误时,MSE+Sigmoid 的组合却提供了非常小的梯度,导致学习过程极其缓慢。这就好像一个学生考了0分,老师却只是轻轻地提醒他“下次努力”,而不是严厉地指出错误。

3. 为什么交叉熵是分类任务的更优选择?

现在我们来看看为什么交叉熵能够很好地解决上述问题。

与概率分布的本质契合

交叉熵天生就是用来衡量两个概率分布之间的距离的。在分类任务中,我们的目标就是让模型预测的概率分布尽可能地接近真实的概率分布(one-hot 编码)。因此,使用交叉熵作为损失函数在理论上是更加合理的。它源于最大似然估计(Maximum Likelihood Estimation),即我们希望找到一组模型参数,使得在给定输入数据下,观测到真实标签的概率最大。

强大的梯度,促进快速学习

让我们再次以二分类问题为例,这次使用交叉熵损失函数 L=−(ytruelog⁡(σ(z))+(1−ytrue)log⁡(1−σ(z)))L = - (y_{true} \log(\sigma(z)) + (1 - y_{true}) \log(1 - \sigma(z)))L=(ytruelog(σ(z))+(1ytrue)log(1σ(z)))
我们来计算损失函数对 zzz 的梯度:
∂L∂z=∂L∂σ(z)∂σ(z)∂z=−(ytrueσ(z)−1−ytrue1−σ(z))⋅σ′(z)\frac{\partial L}{\partial z} = \frac{\partial L}{\partial \sigma(z)} \frac{\partial \sigma(z)}{\partial z} = -(\frac{y_{true}}{\sigma(z)} - \frac{1-y_{true}}{1-\sigma(z)}) \cdot \sigma'(z)zL=σ(z)Lzσ(z)=(σ(z)ytrue1σ(z)1ytrue)σ(z)

代入 σ′(z)=σ(z)(1−σ(z))\sigma'(z) = \sigma(z)(1-\sigma(z))σ(z)=σ(z)(1σ(z)),经过化简,我们可以得到一个非常简洁和优美的结果:
∂L∂z=σ(z)−ytrue\frac{\partial L}{\partial z} = \sigma(z) - y_{true}zL=σ(z)ytrue

这个梯度形式非常直观:梯度的大小正比于预测值与真实值之间的差异

  • 当模型预测错误时(例如 ytrue=1y_{true}=1ytrue=1, σ(z)≈0\sigma(z) \approx 0σ(z)0),梯度 ≈−1\approx -11,这是一个很大的负梯度,会促使模型参数进行大幅度的更新。
  • 当模型预测正确时(例如 ytrue=1y_{true}=1ytrue=1, σ(z)≈1\sigma(z) \approx 1σ(z)1),梯度 ≈0\approx 00,模型趋于稳定。

这种特性使得模型在犯错时能够快速学习和纠正,大大提高了训练效率。

4. 一个直观的例子

假设我们正在训练一个模型来识别图片中的猫(ytrue=1y_{true}=1ytrue=1)。

  • 场景1:模型非常不确定
    模型预测图片是猫的概率为 0.5。

    • MSE 梯度2(0.5−1)⋅(0.5⋅0.5)=−0.252(0.5 - 1) \cdot (0.5 \cdot 0.5) = -0.252(0.51)(0.50.5)=0.25
    • CE 梯度0.5−1=−0.50.5 - 1 = -0.50.51=0.5

    在这个阶段,CE 的梯度比 MSE 的梯度更大,学习更快。

  • 场景2:模型非常有信心地犯错
    模型预测图片是猫的概率为 0.1 (即认为它不是猫)。

    • MSE 梯度2(0.1−1)⋅(0.1⋅0.9)=−0.1622(0.1 - 1) \cdot (0.1 \cdot 0.9) = -0.1622(0.11)(0.10.9)=0.162
    • CE 梯度0.1−1=−0.90.1 - 1 = -0.90.11=0.9

    可以看到,MSE 的梯度反而变小了!模型越错,梯度越小,学习越慢。而 CE 的梯度变得非常大,促使模型快速纠正这个严重的错误。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

布鲁格若门

对你有用的话真是太好了

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值