【Python深度学习】——交叉熵|KL散度|交叉熵损失函数

1. 交叉熵Cross Entropy

1.1 交叉熵的含义

交叉熵是两个概率分布之间差异的一种度量.
在机器学习中被广泛应用,尤其是分类问题中。

1.2 交叉熵的公式

H ( P , Q ) = − ∑ i P ( x i ) log ⁡ Q ( x i ) H(P, Q) = -\sum_{i} P(x_i) \log Q(x_i) H(P,Q)=iP(xi)logQ(xi)
其中, P ( x i ) P(x_i) P(xi)是真实分布中第 ( i ) 个事件的概率;
Q ( x i ) Q(x_i) Q(xi)是模型预测分布中第 ( i ) 个事件的概率。

1.3 交叉熵的特点

  1. 当模型的预测分布 𝑄 越接近真实分布 𝑃 时,交叉熵值越小。
  2. 交叉熵的值总是大于系统真实熵的值。

2. KL散度

2.1 KL散度的含义

KL散度也称为相对熵, 它也是度量两个概率分布之间差异的一种方式. 它是以Kullback和Leibler两个人命名的.

2.2 KL散度的公式

假设有两个概率分布 𝑃 和 𝑄,其中 𝑃 是真实分布,𝑄 是模型的预测分布。KL散度表示二者的交叉熵的差:
D K L ( P ∥ Q ) = H ( P , Q ) − H ( P ) D_{KL}(P \parallel Q) = H(P, Q) - H(P) DKL(PQ)=H(P,

### 知识蒸馏中交叉熵KL的区别及应用场景 在知识蒸馏的过程中,交叉熵KL是两种常用的损失函数。它们的使用方式和适用场景有所不同,具体如下: #### 1. KL的定义与应用 KL(Kullback-Leibler Divergence)用于衡量两个概率分布之间的差异。在知识蒸馏中,KL被用来量化学生模型的输出分布与教师模型的输出分布之间的差异[^1]。通过最小化KL,学生模型能够学习到教师模型的概率分布,从而获得更优的性能。 公式表示为: \[ D_{KL}(P || Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)} \] 其中 \(P\) 表示教师模型的输出分布,\(Q\) 表示学生模型的输出分布。 在实际实现中,`torch.nn.KLDivLoss` 是一个常用工具,用于计算两个分布之间的KL[^2]。由于KL直接比较了两个分布的差异,因此它特别适合用于知识蒸馏中的软标签训练。 #### 2. 交叉熵的定义与应用 交叉熵(Cross-Entropy)是一种衡量两个概率分布之间差异的指标,常用于分类任务中。其公式为: \[ H(P, Q) = -\sum_{x} P(x) \log Q(x) \] 在知识蒸馏中,交叉熵通常用于硬标签训练,即目标分布是一个one-hot向量[^3]。在这种情况下,交叉熵可以简化为对数损失(Log Loss),并直接用于优化模型参数。 #### 3. 交叉熵KL的区别 - **目标分布**:KL适用于软标签训练,目标分布是教师模型的输出概率分布;而交叉熵更适合硬标签训练,目标分布通常是one-hot编码。 - **计算复杂**:KL需要同时考虑目标分布和模型分布,计算相对复杂;交叉熵则仅需关注目标分布下的对数概率值,计算更为简单。 - **应用场景**:KL广泛应用于知识蒸馏中的软标签训练,特别是在深度学习模型(如DeepSeek等大型语言模型)中;交叉熵则更多用于传统的分类任务或硬标签训练场景。 #### 4. 实现代码示例 以下是一个简单的PyTorch代码示例,展示如何使用KL交叉熵进行知识蒸馏训练: ```python import torch import torch.nn as nn import torch.nn.functional as F # 假设教师模型和学生模型的输出 teacher_logits = torch.tensor([[3.0, 1.0, 0.2], [2.0, 5.0, 1.0]]) student_logits = torch.tensor([[2.0, 1.5, 0.3], [1.5, 4.0, 1.2]]) # 转换为概率分布 temperature = 2.0 teacher_probs = F.softmax(teacher_logits / temperature, dim=-1) student_probs = F.log_softmax(student_logits / temperature, dim=-1) # 计算KL损失 kld_loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ** 2) print(f"KL Divergence Loss: {kld_loss.item()}") # 计算交叉熵损失(假设硬标签为[0, 1]) hard_labels = torch.tensor([0, 1]) cross_entropy_loss = nn.CrossEntropyLoss()(student_logits, hard_labels) print(f"Cross Entropy Loss: {cross_entropy_loss.item()}") ``` ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

steptoward

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值