Softmax函数下的交叉熵损失含义与求导

本文深入探讨了softmax函数及其在分类任务中的应用,以及交叉熵损失函数的意义。通过分析交叉熵为何能衡量分类效果,解释了其与熵、自信息和Kullback-Leibler Divergence的关系。同时,介绍了损失函数的梯度计算,强调了交叉熵在优化过程中评估模型与真实分布差异的重要性。

交叉熵损失函数(CrossEntropy Function)是分类任务中十分常用的损失函数,但若仅仅看它的形式,我们不容易直接靠直觉来感受它的正确性,因此我查阅资料写下本文,以求彻底搞懂。

1.Softmax

首先是我们的softmax函数。
它很简单,以一个向量作为输入,把向量的每个分量,用指数函数归一化后输出。具体来说,其数学形式为:
softmax(xi)=exi∑kexksoftmax(x_i) = \frac{e^{x_i}}{\sum_ke^{x_k}}softmax(xi)=kexkexi
xix_ixi为向量中第i个项。

softmaxsoftmaxsoftmax的输出向量为input_vectorinput\_vectorinput_vector,当input_vectorinput\_vectorinput_vector中某个分量过大时,可能导致其指数形式数值过大溢出。因此我们需要将input_vectorinput\_vectorinput_vector每一项减去max(input_vector)max(input\_vector)max(input_vector),再输入softmaxsoftmaxsoftmax中去。可以验证,减去最大值后,softmaxsoftmaxsoftmax的输出不会改变。

2.交叉熵损失函数

交叉熵损失函数作为分类任务最常用的loss functionloss\ functionloss function,我们理应深刻理解并熟知其形式与含义。

2.1 交叉熵为什么能衡量分类结果的好坏?

这是交叉熵损失函数的公式:
Loss=−∑iyilogaiLoss = -\sum_iy_iloga_iLoss=iyilogai
其中,aia_iai为softmax函数输出向量的第iii个分量,yiy_iyi为样本真实标签的one hotone\ hotone hot形式的第i个分量。
也就是说,实际上求和的项中只有一个不为0,LossLossLoss就等于−logak-loga_klogak,其中k对应的是样本的真实标签。
在这里插入图片描述

根据对数函数的性质,容易得出−logak-loga_klogakaka_kak增大而减小。也就是说,softmaxsoftmaxsoftmax输出向量对应labellabellabel的那项越大,LossLossLoss就越小,这正是我们希望的。

*2.2 熵、交叉熵(选读)

2.2.1 自信息

自信息是用来量化信息量大小的值。可以这样认为,越让我们觉得惊讶的事情带来的信息量越大。
一个优生考出好成绩的概率很大,当他期末成绩优秀时我们不会感到奇怪。但当一个学习很差的同学考得很好,我们就会感到非常惊讶。前者发生时,我们难以挖掘出隐含信息,而后者发生时,我们会推测这名学生可能有作弊行为,抑或他近期学习用功,找到了高效的学习方法。这样来看,后者带来的信息量明显大于前者。也就是说,概率更小的事情的发生能带来更多的信息。自然地,我们可以用概率来衡量自信息,公式如下:
s=log(1/pi)=−logpis = log(1/p_i)=-logp_is=log(1/pi)=logpi
其中pip_ipi是事件i发生的概率。

2.2.2 熵

熵,即为包含信息量的大小。对于一个事件(如分类),它的结果有很多种(如分为狗、猫…),分别对应不同的概率。很自然的,我们用各个结果自信息的加权平均作为事件的熵。
假设有n个结果(如分类有n类),则可以写成:
e=−∑ipilogpie = -\sum_i{p_ilogp_i}e=ipilogpi

2.2.3 交叉熵

交叉熵和熵是类似的概念,区别在于,现在我们是对两个不同分布进行定义。离散型变量的交叉熵可定义为:
H=−∑pilogqiH =-\sum{p_ilogq_i}H=pilogqi
可以理解为,每个结果的实际概率为pip_ipi,却有人将概率估计为qiq_iqi。(当前样本分为类i的真正概率为pip_ipi,但是分类器认为该样本为第i类的概率是qiq_iqi)。也就是说,我们带着某个主观认知去接触某个客观随机现象的时候,会产生的平均自信息量。
当我们主观上认为一个事情发生的概率很低(即−logqi(x)-logq_i(x)logqi(x)很大),但是客观上发生概率pip_ipi很大的时候,也就是主观认知和客观现实非常不匹配的时候,交叉熵的结果会很大。

交叉熵衡量了两个概率分布的差异。其值越大,两个概率分布相差越大;其值越小,则两个概率分布的差异越小。

从概率论角度来讲,我们要让分类器结果好,就是要让输出的各类的概率分布与真实分布接近,也就是要优化交叉熵,让其尽量小。

2.2.4 Kullback-Leibler Divergence(K-L 散度)

交叉熵可以衡量我们基于某种主观认识去感受客观世界时,会产生的平均自信息量。但是根据上面的公式,即使主观和客观完全匹配(交叉熵等于信息熵),只要事件仍然随机而非确定,产生的自信息量就一定大于0。那我们应该如何更好地度量主观认识和客观事实之间差异呢?可以用当前对事件的主观认识产生的期望和完全正确认识事件时产生的期望的差值来衡量,也就是相对熵(常称作KL-散度),通常写作:
在这里插入图片描述
当我们的主观认知完全匹配客观现实的时候,KL-散度应该等于0,其它任何时候都会大于0。

由于存在恒为正这一性质,KL-散度经常用于描述两个分布是否接近,也就是作为两个分布之间“距离”的度量;不过由于运算不满足交换律,所以有不能完全等同于“距离”来理解。

机器学习中通常直接用交叉熵作为损失函数的原因在于,客观分布并不随模型参数变化而变化,所以即使是优化KL-散度,对参数求导后就只剩下交叉熵的导数了。

3. 交叉熵损失求导(梯度计算)

这一部分先鸽在这,后面ipad笔修好了拿手推一下。

<think>我们被要求解释为什么在神经网络中softmax函数通常交叉熵损失函数配合使用。结合用户提供的引用内容,我们可以从数学原理和实际计算两个角度来分析。 ### 1. **数学推导的简化** - 在分类问题中,我们通常用softmax函数将神经网络的原始输出(即logits)转换为概率分布。假设有$C$类,对于输入$x$,网络输出向量$z = (z_1, z_2, \dots, z_C)$,则softmax函数定义为: $$ \hat{y}_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}} \quad \text{for } i=1,2,\dots,C $$ 其中$\hat{y}_i$表示预测为第$i$类的概率。 - 交叉熵损失函数(又称负对数似然)用于衡量预测概率分布$\hat{y}$真实标签$y$(通常为one-hot编码)之间的差异: $$ L = -\sum_{i=1}^{C} y_i \log \hat{y}_i $$ 由于$y$是one-hot向量(仅有一个元素为1,其余为0),实际计算时只需考虑真实类别的概率:$L = -\log \hat{y}_k$(其中$k$为真实类别索引)。 - **关键点**:当使用交叉熵损失函数softmax组合时,求导结果异常简洁。具体推导如下: 令$L = -\log \hat{y}_k$,且$\hat{y}_k = \frac{e^{z_k}}{\sum_j e^{z_j}}$。则: $$ \frac{\partial L}{\partial z_i} = \hat{y}_i - y_i $$ 其中$y_i$是真实标签的one-hot表示。这个求导结果非常简洁,避免了复杂的链式求导过程[^1][^2]。 ### 2. **数值稳定性** - 直接计算softmax时,指数运算可能导致数值溢出(特别是当$z_i$很大时)。但在计算交叉熵损失时,我们可以将softmax交叉熵合并为一个整体计算,从而避免中间步骤的数值不稳定问题。实际实现时通常采用以下技巧: $$ \log \hat{y}_k = z_k - \log \sum_j e^{z_j} $$ 通过减去最大值$m = \max(z_j)$来保证数值稳定: $$ \hat{y}_k = \frac{e^{z_k - m}}{\sum_j e^{z_j - m}} $$ 这样指数部分的值不会过大,避免了溢出问题[^2][^3]。 ### 3. **梯度传播效率** - 如前所述,梯度计算简化为$\frac{\partial L}{\partial z_i} = \hat{y}_i - y_i$。这意味着: - 梯度仅取决于当前预测概率真实标签的差异。 - 在反向传播时,梯度可以直接从输出层传递到前一层的权重,无需计算softmax的中间梯度,从而提高了计算效率。 ### 4. **理论一致性** - 交叉熵损失函数本质上是衡量两个概率分布(真实分布$y$预测分布$\hat{y}$)的差异。而softmax输出的恰好是一个合法的概率分布(和为1且非负),因此两者在数学形式上天然匹配[^1]。 ### 代码示例说明 引用[2][^2]和引用[3][^3]中的代码展示了softmax的基本实现,但未涉及交叉熵的结合。实际PyTorch中通常这样实现: ```python import torch import torch.nn as nn # 定义网络输出(logits) logits = torch.tensor([[2.0, 1.0, 0.1]]) # 1个样本,3个类别 # 真实标签(假设为第0类) true_labels = torch.tensor([0]) # 组合计算 criterion = nn.CrossEntropyLoss() # 内部已包含softmax loss = criterion(logits, true_labels) # 手动计算验证 softmax = nn.Softmax(dim=1) probs = softmax(logits) manual_loss = -torch.log(probs[0, 0]) # 对应真实类别的概率 print(f"PyTorch loss: {loss.item()}, Manual loss: {manual_loss.item()}") ``` ### 总结原因 1. **梯度计算高效**:$\frac{\partial L}{\partial z_i} = \hat{y}_i - y_i$ 形式简单,避免梯度弥散。 2. **数值稳定性**:合并计算可防止指数溢出。 3. **理论匹配**:交叉熵衡量概率分布差异,softmax输出概率分布。 §§相关问题§§ 1. 为什么在二分类问题中sigmoid+BCEsoftmax+cross_entropy等价? 2. 如何处理多标签分类(multi-label classification)的损失函数? 3. 交叉熵损失函数KL散度有何联系? </think> ### Softmax函数配合交叉熵损失函数的原因分析 #### 1. **梯度计算的优化(核心优势)** - **数学简化**:当softmax交叉熵结合时,损失函数对网络输出$z_i$的梯度简化为: $$ \frac{\partial L}{\partial z_i} = \hat{y}_i - y_i $$ 其中$\hat{y}_i$是softmax输出的概率,$y_i$是真实标签(one-hot编码)。这种形式: - 只涉及**概率差值计算**,无复杂导数链式法则 - **计算效率高**(仅需减法运算) - **避免梯度弥散**(梯度值直接误差相关) - **对比其他损失函数**: - 若使用均方误差(MSE),梯度包含$\hat{y}_i(1-\hat{y}_i)$项,易导致梯度消失[^1] - 交叉熵梯度始终保持误差$\hat{y}_i - y_i$的线性关系 #### 2. **概率建模的匹配性** - **理论一致性**: - Softmax输出**归一化概率分布**:$\sum_i \hat{y}_i = 1$ - 交叉熵衡量**两个概率分布差异**:$L = -\sum y_i \log(\hat{y}_i)$ - 两者共同构建了概率框架下的优化目标(最大似然估计) - **信息论解释**: - 交叉熵等价于最小化预测分布真实分布的KL散度 - Softmax确保输出满足概率分布公理 #### 3. **数值稳定性** - **计算实践优势**: ```python # 实际实现合并计算(避免中间值溢出) logits = np.array([2.0, 1.0, 0.1]) max_logit = np.max(logits) stable_probs = np.exp(logits - max_logit) / np.sum(np.exp(logits - max_logit)) # 引用[2][^2] ``` - 联合计算可通过**Log-Sum-Exp技巧**提升数值稳定性 - 单独计算softmax再求交叉熵可能导致中间值溢出 #### 4. **多分类场景的普适性** - **统一处理能力**: - 二分类是softmax的特例(sigmoid等价于两类softmax) - 可直接扩展到任意多分类任务 - **决策边界优化**: - Softmax的指数形式天然强化最大类别概率 - 交叉熵推动模型对正确类别的置信度提升 ### 代码实现对比 ```python # 非优化实现(数值不稳定) def unstable_loss(logits, y_true): probs = np.exp(logits) / np.sum(np.exp(logits)) # 引用[3] return -np.sum(y_true * np.log(probs)) # 优化实现(PyTorch风格合并计算) def cross_entropy(logits, y_true): max_logits = np.max(logits, axis=-1, keepdims=True) stable_logits = logits - max_logits log_probs = stable_logits - np.log(np.sum(np.exp(stable_logits), axis=-1, keepdims=True)) return -np.sum(y_true * log_probs) ``` ### 总结优势 | **方面** | **优势表现** | |-------------------|-----------------------------------------------------------------------------| | **梯度质量** | 梯度简洁($\hat{y}_i - y_i$),无饱和现象,反向传播高效[^1] | | **计算效率** | 仅需加减运算,无指数/对数嵌套求导 | | **概率解释性** | 严格匹配概率建模框架 | | **数值稳定性** | 可通过联合计算避免浮点溢出 | | **任务普适性** | 自然支持二分类/多分类,无需修改损失函数 |
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值