PyTorch图像分类损失函数全解析

在PyTorch中,图像分类任务常用的损失函数主要围绕交叉熵(Cross Entropy及其变体展开,适用于单标签或多标签分类场景。以下是PyTorch内置的图像分类相关损失函数的详细总结,涵盖适用场景、输入要求、关键参数及示例代码。

一、核心损失函数

1. CrossEntropyLoss(交叉熵损失)

适用场景:单标签多分类任务(每个样本仅属于一个类别)。
原理:结合了 LogSoftmaxNLLLoss(负对数似然损失),直接对模型输出的未归一化 logits 计算损失,无需手动添加激活函数。

输入要求

  • input:形状为 (N, C)(N, C, d1, d2, ..., dK) 的张量,其中 N 是批量大小,C 是类别数,d1...dK 是空间维度(如图像的高、宽)。
  • target:形状为 (N,)(N, d1, d2, ..., dK) 的长整型张量,表示每个样本的真实类别索引(范围 [0, C-1])。

关键参数

  • weight:形状为 (C,) 的张量,用于为每个类别分配权重(解决类别不平衡问题)。
  • reduction:损失归约方式,可选 'mean'(默认,平均)、'sum'(求和)、'none'(不归约,返回每个样本的损失)。
  • ignore_index:忽略某个类别(不计算其损失),设为 -1 时不过滤。

示例代码

import torch
import torch.nn as nn

# 模型输出 logits(未归一化)
logits = torch.randn(3, 5)  # 3个样本,5个类别
target = torch.tensor([1, 3, 0])  # 真实类别索引

# 初始化损失函数(带类别权重)
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.5, 0.2, 0.3, 0.9]))
loss = criterion(logits, target)
print(loss)  # 输出标量损失
2. NLLLoss(负对数似然损失)

适用场景:多分类任务(需配合 LogSoftmax 层输出对数概率)。
原理:直接计算真实类别对应预测概率的负对数,要求模型输出已通过 LogSoftmax 归一化的对数概率。

输入要求

  • input:形状同 CrossEntropyLoss,但值为对数概率(log_softmax 输出)。
  • target:形状同 CrossEntropyLoss,为类别索引。

关键参数:与 CrossEntropyLoss 一致(weight, reduction, ignore_index)。

示例代码

# 模型输出 log_softmax(已归一化)
log_probs = torch.log_softmax(torch.randn(3, 5), dim=1)  # 对每行做 log_softmax
target = torch.tensor([1, 3, 0])

criterion = nn.NLLLoss(weight=torch.tensor([0.1, 0.5, 0.2, 0.3, 0.9]))
loss = criterion(log_probs, target)
print(loss)
3. KLDivLoss(KL散度损失)

适用场景:分布匹配任务(如知识蒸馏、生成模型中的教师-学生模型对齐)。
原理:计算预测分布与真实分布的KL散度(KL(p||q) = Σ p(x) * log(p(x)/q(x))),要求预测分布是归一化的概率,真实分布通常为one-hot向量。

输入要求

  • input:形状同前,值为预测分布的对数概率(log_softmax 输出)。
  • target:形状为 (N, C)(N, C, d1, ...dK) 的浮点型张量,表示真实分布的概率(如one-hot向量)。

关键参数

  • reduction:归约方式('mean' 默认,'batchmean' 按批量平均,'sum' 求和)。

示例代码(知识蒸馏场景):

# 教师模型输出(真实分布)
teacher_logits = torch.randn(3, 5)
teacher_probs = torch.softmax(teacher_logits, dim=1)

# 学生模型输出(log_softmax)
student_logits = torch.randn(3, 5)
student_log_probs = torch.log_softmax(student_logits, dim=1)

criterion = nn.KLDivLoss(reduction='batchmean')
loss = criterion(student_log_probs, teacher_probs)
print(loss)

二、多标签分类损失函数

4. BCEWithLogitsLoss(二值交叉熵 + Sigmoid)

适用场景:多标签二分类任务(每个样本可属于多个类别,每个类别独立判断为0/1)。
原理:将 Sigmoid 激活函数与二值交叉熵(BCE)合并,避免数值不稳定(直接对 logits 计算,而非概率)。

输入要求

  • input:形状为 (N, C)(N, C, d1, ...dK) 的张量,表示每个类别的 logits(未经过 Sigmoid)。
  • target:形状同 input 的浮点型张量,取值为 01(表示每个类别的真实标签)。

关键参数

  • weight:形状为 (C,) 的张量,为每个类别分配权重(解决类别不平衡)。
  • reduction:归约方式('mean' 默认)。
  • pos_weight:正类样本的权重(用于平衡正负样本,公式:loss = - (pos_weight * y * log(p) + (1-y) * log(1-p)))。

示例代码

# 模型输出 logits(多标签,每个类别独立)
logits = torch.randn(3, 5)  # 3个样本,5个标签
target = torch.tensor([[1, 0, 1, 0, 1], [0, 1, 0, 1, 0], [1, 1, 0, 0, 0]]).float()  # 多标签真实值

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]))  # 正类权重
loss = criterion(logits, target)
print(loss)
5. BCELoss(二值交叉熵)

适用场景:多标签二分类(需手动对 logits 应用 Sigmoid 激活)。
原理:直接计算每个类别的二值交叉熵,要求模型输出已通过 Sigmoid 归一化的概率(范围 [0,1])。

输入要求

  • input:形状同 BCEWithLogitsLoss,值为概率(sigmoid 输出)。
  • target:形状同 input,取值为 01

关键参数:与 BCEWithLogitsLoss 类似(weight, reduction)。

示例代码

# 模型输出 sigmoid 概率
probs = torch.sigmoid(torch.randn(3, 5))  # 概率范围 [0,1]
target = torch.tensor([[1, 0, 1, 0, 1], [0, 1, 0, 1, 0], [1, 1, 0, 0, 0]]).float()

criterion = nn.BCELoss(weight=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]))
loss = criterion(probs, target)
print(loss)

三、其他损失函数(可选)

6. HingeEmbeddingLoss(铰链嵌入损失)

适用场景:类似SVM的分类任务(通过嵌入向量的间隔最大化区分类别)。
原理:鼓励相似样本的嵌入向量具有大的正得分,不相似样本具有小的负得分(公式:loss(x, y) = max(0, margin - y*x),其中 y ∈ {1, -1})。

输入要求

  • input:形状为 (N, D)(N, D, d1, ...dK) 的张量,表示嵌入向量。
  • target:形状为 (N,)(N, d1, ...dK) 的长整型张量,取值为 1(相似)或 -1(不相似)。

关键参数margin(间隔,默认 1.0),reduction(归约方式)。

示例代码

embeddings = torch.randn(3, 10)  # 3个样本,10维嵌入
targets = torch.tensor([1, -1, 1])  # 相似/不相似标签

criterion = nn.HingeEmbeddingLoss(margin=1.0)
loss = criterion(embeddings, targets)
print(loss)

总结:选择建议

  • 单标签多分类:优先 CrossEntropyLoss(无需手动激活,高效)。
  • 多标签分类:使用 BCEWithLogitsLoss(自动融合 Sigmoid 和 BCE,数值稳定)。
  • 分布匹配(如蒸馏)KLDivLoss(需模型输出 log_softmax)。
  • 类别不平衡:通过 weight 参数(CrossEntropyLoss/BCEWithLogitsLoss)或调整样本权重。
  • SVM-like分类HingeEmbeddingLoss(需自定义嵌入层)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值