本节目标
- 理解损失函数的本质作用
- 掌握分类,回归,分割,自监督等任务中常见的损失函数类型
- 理解每类损失函数的公式,含义,适用场景,优缺点
- 学会在PyTorch中使用与定制损失函数
一、什么是损失函数?
损失函数是一个标量函数,用来衡量模型预测输出 与真实标签
之间的差距。
它是训练中唯一优化的目标,我们通过反向传播算法计算损失函数的梯度,从而更新网络参数。
二、损失函数的数学形式
通用形式:
预测值和标签值之间的损失值,损失值形式多样
目标:最小化整体损失
三、常见损失函数分类与详解
此处我们按照任务类型归类讲解:
1.分类任务常用损失函数
1.1 交叉熵损失(Cross Entropy Loss)
用于多类分类,衡量两个概率分布之间的差异。
公式(Softmax后):
说明:
-
:真实标签的one-hot编码
:Softmax输出的预测概率
适用任务:
- 多分类(单标签):如图像分类,情感分类
PyTorch示例:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, target) # output: [batch, num_classes]; target: [batch]
1.2 二分类交叉熵(Binary Cross Entropy)
用于二分类或多标签分类(每个标签看作独立二分类)
公式(Sigmoid后):
适用任务:多标签图像标注,多疾病诊断
PyTorch示例:
loss_fn = nn.BCELoss() # 预测值需经过 Sigmoid
或更推荐
loss_fn = nn.BCEWithLogitsLoss() # 自动融合 Sigmoid
2.回归任务常用损失函数
2.1 均方误差(MSE)
- 对大误差敏感(平方会放大差距)
- 收敛更稳定,导数连续平滑
2.2 平均绝对误差(MAE)
- 对异常值更鲁棒,但收敛速度慢
- 导数连续(在0点)
使用建议:
- 数据中异常值多➡MAE
- 平滑回归任务(如温度预测)➡MSE
nn.MSELoss()
nn.L1Loss() # MAE
3.分割任务常用损失函数
3.1 Dice Loss
用于衡量预测掩码与真实掩码的重叠程度
- 适用任务:医学图像分割,实例分割
- 优点:适应类别不平衡场景
3.2 IoU Loss(交并比)
与Dice类似,但更注重预测边界重合
4.自监督/对比学习任务
4.1 InfoNCE(对比损失)
:查询样本
:正样本
:负样本
:温度参数
- 适用场景:CLIP、SimCLR、MoCo、BERT
四、PyTorch中使用与自定义损失函数
常规用法:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(predictions, targets)
自定义损失函数:
class MyLoss(nn.Module):
def forward(self, y_pred, y_true):
return torch.mean(torch.abs(y_pred - y_true)) # MAE
五、不同损失函数的对比总结表
任务类型 | 损失函数 | 特点 | 适用 |
分类 | CrossEntropy | 对概率分布差异敏感 | 多分类 |
分类 | BCE/BCEWithLogits | 多个独立二分类 | 多分类 |
回归 | MSE | 平滑,对异常值敏感 | 普通回归 |
回归 | MAE | 对异常值鲁棒 | 金融预测 |
分割 | Dice Loss | 类别不平衡适配好 | 医学分割 |
自监督 | InfoNCE | 拉近正样本,推远负样本 | 表征学习 |