在PyTorch中,图像分类任务常用的损失函数主要围绕交叉熵(Cross Entropy及其变体展开,适用于单标签或多标签分类场景。以下是PyTorch内置的图像分类相关损失函数的详细总结,涵盖适用场景、输入要求、关键参数及示例代码。
一、核心损失函数
1. CrossEntropyLoss
(交叉熵损失)
适用场景:单标签多分类任务(每个样本仅属于一个类别)。
原理:结合了 LogSoftmax
和 NLLLoss
(负对数似然损失),直接对模型输出的未归一化 logits 计算损失,无需手动添加激活函数。
输入要求:
input
:形状为(N, C)
或(N, C, d1, d2, ..., dK)
的张量,其中N
是批量大小,C
是类别数,d1...dK
是空间维度(如图像的高、宽)。target
:形状为(N,)
或(N, d1, d2, ..., dK)
的长整型张量,表示每个样本的真实类别索引(范围[0, C-1]
)。
关键参数:
weight
:形状为(C,)
的张量,用于为每个类别分配权重(解决类别不平衡问题)。reduction
:损失归约方式,可选'mean'
(默认,平均)、'sum'
(求和)、'none'
(不归约,返回每个样本的损失)。ignore_index
:忽略某个类别(不计算其损失),设为-1
时不过滤。
示例代码:
import torch
import torch.nn as nn
# 模型输出 logits(未归一化)
logits = torch.randn(3, 5) # 3个样本,5个类别
target = torch.tensor([1, 3, 0]) # 真实类别索引
# 初始化损失函数(带类别权重)
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.5, 0.2, 0.3, 0.9]))
loss = criterion(logits, target)
print(loss) # 输出标量损失
2. NLLLoss
(负对数似然损失)
适用场景:多分类任务(需配合 LogSoftmax
层输出对数概率)。
原理:直接计算真实类别对应预测概率的负对数,要求模型输出已通过 LogSoftmax
归一化的对数概率。
输入要求:
input
:形状同CrossEntropyLoss
,但值为对数概率(log_softmax
输出)。target
:形状同CrossEntropyLoss
,为类别索引。
关键参数:与 CrossEntropyLoss
一致(weight
, reduction
, ignore_index
)。
示例代码:
# 模型输出 log_softmax(已归一化)
log_probs = torch.log_softmax(torch.randn(3, 5), dim=1) # 对每行做 log_softmax
target = torch.tensor([1, 3, 0])
criterion = nn.NLLLoss(weight=torch.tensor([0.1, 0.5, 0.2, 0.3, 0.9]))
loss = criterion(log_probs, target)
print(loss)
3. KLDivLoss
(KL散度损失)
适用场景:分布匹配任务(如知识蒸馏、生成模型中的教师-学生模型对齐)。
原理:计算预测分布与真实分布的KL散度(KL(p||q) = Σ p(x) * log(p(x)/q(x))
),要求预测分布是归一化的概率,真实分布通常为one-hot向量。
输入要求:
input
:形状同前,值为预测分布的对数概率(log_softmax
输出)。target
:形状为(N, C)
或(N, C, d1, ...dK)
的浮点型张量,表示真实分布的概率(如one-hot向量)。
关键参数:
reduction
:归约方式('mean'
默认,'batchmean'
按批量平均,'sum'
求和)。
示例代码(知识蒸馏场景):
# 教师模型输出(真实分布)
teacher_logits = torch.randn(3, 5)
teacher_probs = torch.softmax(teacher_logits, dim=1)
# 学生模型输出(log_softmax)
student_logits = torch.randn(3, 5)
student_log_probs = torch.log_softmax(student_logits, dim=1)
criterion = nn.KLDivLoss(reduction='batchmean')
loss = criterion(student_log_probs, teacher_probs)
print(loss)
二、多标签分类损失函数
4. BCEWithLogitsLoss
(二值交叉熵 + Sigmoid)
适用场景:多标签二分类任务(每个样本可属于多个类别,每个类别独立判断为0/1)。
原理:将 Sigmoid 激活函数与二值交叉熵(BCE)合并,避免数值不稳定(直接对 logits 计算,而非概率)。
输入要求:
input
:形状为(N, C)
或(N, C, d1, ...dK)
的张量,表示每个类别的 logits(未经过 Sigmoid)。target
:形状同input
的浮点型张量,取值为0
或1
(表示每个类别的真实标签)。
关键参数:
weight
:形状为(C,)
的张量,为每个类别分配权重(解决类别不平衡)。reduction
:归约方式('mean'
默认)。pos_weight
:正类样本的权重(用于平衡正负样本,公式:loss = - (pos_weight * y * log(p) + (1-y) * log(1-p))
)。
示例代码:
# 模型输出 logits(多标签,每个类别独立)
logits = torch.randn(3, 5) # 3个样本,5个标签
target = torch.tensor([[1, 0, 1, 0, 1], [0, 1, 0, 1, 0], [1, 1, 0, 0, 0]]).float() # 多标签真实值
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0])) # 正类权重
loss = criterion(logits, target)
print(loss)
5. BCELoss
(二值交叉熵)
适用场景:多标签二分类(需手动对 logits 应用 Sigmoid 激活)。
原理:直接计算每个类别的二值交叉熵,要求模型输出已通过 Sigmoid 归一化的概率(范围 [0,1]
)。
输入要求:
input
:形状同BCEWithLogitsLoss
,值为概率(sigmoid
输出)。target
:形状同input
,取值为0
或1
。
关键参数:与 BCEWithLogitsLoss
类似(weight
, reduction
)。
示例代码:
# 模型输出 sigmoid 概率
probs = torch.sigmoid(torch.randn(3, 5)) # 概率范围 [0,1]
target = torch.tensor([[1, 0, 1, 0, 1], [0, 1, 0, 1, 0], [1, 1, 0, 0, 0]]).float()
criterion = nn.BCELoss(weight=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]))
loss = criterion(probs, target)
print(loss)
三、其他损失函数(可选)
6. HingeEmbeddingLoss
(铰链嵌入损失)
适用场景:类似SVM的分类任务(通过嵌入向量的间隔最大化区分类别)。
原理:鼓励相似样本的嵌入向量具有大的正得分,不相似样本具有小的负得分(公式:loss(x, y) = max(0, margin - y*x)
,其中 y ∈ {1, -1}
)。
输入要求:
input
:形状为(N, D)
或(N, D, d1, ...dK)
的张量,表示嵌入向量。target
:形状为(N,)
或(N, d1, ...dK)
的长整型张量,取值为1
(相似)或-1
(不相似)。
关键参数:margin
(间隔,默认 1.0
),reduction
(归约方式)。
示例代码:
embeddings = torch.randn(3, 10) # 3个样本,10维嵌入
targets = torch.tensor([1, -1, 1]) # 相似/不相似标签
criterion = nn.HingeEmbeddingLoss(margin=1.0)
loss = criterion(embeddings, targets)
print(loss)
总结:选择建议
- 单标签多分类:优先
CrossEntropyLoss
(无需手动激活,高效)。 - 多标签分类:使用
BCEWithLogitsLoss
(自动融合 Sigmoid 和 BCE,数值稳定)。 - 分布匹配(如蒸馏):
KLDivLoss
(需模型输出log_softmax
)。 - 类别不平衡:通过
weight
参数(CrossEntropyLoss
/BCEWithLogitsLoss
)或调整样本权重。 - SVM-like分类:
HingeEmbeddingLoss
(需自定义嵌入层)。