分类问题为什么要用交叉熵

本文探讨了在分类问题中使用交叉熵损失函数而非均方误差的原因。信息熵和相对熵(KL散度)的概念被引入,解释了交叉熵是衡量概率分布差异的有效方式。由于信息熵在实际应用中为固定值,优化相对熵等同于优化交叉熵。均方误差在梯度下降时,由于sigmoid导数的影响,可能导致参数更新缓慢,而交叉熵则避免了这个问题,因此在分类任务中更优。

一提到分类,大家想到的损失函数就是交叉熵,但是有没有想过为什么分类问题要用交叉熵损失,为什么不用均方误差损失呢?本文将详细介绍交叉熵的由来,并分析为什么不使用均方误差。

🤗欢迎关注公众号 funNLPer🤗

1. 信息熵

信息熵就是信息的不确定程度,信息熵越小,信息越确定
信息熵 = ∑ 事件 x 发生的概率 ∗ 验证事件 x 需要的信息量 信息熵=\sum 事件x发生的概率*验证事件x需要的信息量 信息熵=事件x发生的概率验证事件x需要的信息量

事件发生的概率越低,需要越多的信息去验证,所以验证真假需要的信息量和事件发生的概率成反比,假设信息量为 I ( x ) I(x) I(x)
I ( x ) = − l o g   p ( x ) I(x) = -log\, p(x) I(x)=logp(x)

其中负号是用来保证信息量是正数或者零, p ( x ) p(x) p(x)是事件 x x x发生的概率, I ( x ) I(x) I(x) 也被称为随机变量 $x $的自信息 (self-information),描述的是随机变量的某个事件发生所带来的信息量

信息熵即所有信息量的期望
H ( X ) = − ∑ x p ( x ) log ⁡ ( p ( x ) ) = − ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) ) H(X)=-\sum_{x} p(x) \log (p(x))=-\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(p\left(x_{i}\right)\right) H(X)=xp(x)log(p(x))=i=1np(xi)log(p(xi))

其中 n n n 为事件的所有可能性

2. 相对熵(KL散度)

相对熵又称KL散度,如果对于同一个随机变量 x x x有两个单独的概率分布 p ( x ) p(x) p(x) q ( x ) q(x) q(x),可以使用相对熵来衡量这两个分布的差异

D K L ( p ∥ q ) = ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) q ( x i ) ) D_{K L}(p \| q)=\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(\frac{p\left(x_{i}\right)}{q\left(x_{i}\right)}\right) DKL(pq)=i=1np(xi)log(

### 交叉熵损失函数在分类问题中的应用 #### 数学原理 交叉熵损失函数的核心在于衡量两个概率分布之间的差异。对于分类问题,目标是使模型预测的概率分布 \( q \) 尽可能接近真实标签的概率分布 \( p \)[^2]。假设在一个具有 \( K \) 类别的多分类问题中,样本的真实标签为独热编码形式的一维向量 \( y_i \),而模型的输出是一个经过 softmax 函数处理后的概率分布 \( \hat{y}_i \)。 交叉熵损失函数定义如下: \[ L_{CE} = -\sum_{k=1}^{K} y_k \log(\hat{y}_k) \] 其中: - \( y_k \in \{0, 1\} \) 是第 \( k \) 类的真实标签; - \( \hat{y}_k \) 是模型对第 \( k \) 类的预测概率。 通过最大化 \( \hat{y}_k \) 对应于真实类别的值来最小化该损失函数[^1]。 #### 原因分析 选择交叉熵作为分类问题的损失函数有以下几个主要原因: 1. **直接优化目标** 分类问题的目标是最小化预测分布与真实分布之间的差距。由于交叉熵能够量化这一差距,因此它是自然的选择[^2]。 2. **数学性质优良** 交叉熵损失函数具备良好的凸性和可导性,在训练过程中可以通过梯度下降法高效求解。 3. **信息论解释** 在信息论框架下,交叉熵反映了传输真实分布所需的信息量大小。当预测分布越接近真实分布时,所需的额外信息量就越少,从而达到最优状态[^2]。 4. **与其他激活函数兼容性强** 结合 softmax 激活函数使用时,交叉熵损失能有效避免数值不稳定问题,并加速收敛速度[^1]。 以下是基于 PyTorch 实现的一个简单例子,展示如何计算交叉熵损失: ```python import torch import torch.nn as nn # 定义输入张量 (batch_size, num_classes) predictions = torch.tensor([[0.25, 0.75], [0.9, 0.1]], requires_grad=True) # 定义真实标签 (batch_size,) labels = torch.tensor([1, 0]) # 初始化交叉熵损失函数 criterion = nn.CrossEntropyLoss() # 计算损失 loss = criterion(predictions, labels) print(f'Cross entropy loss: {loss.item()}') ``` --- ###
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值