文章目录
摘要
本文总结了 PyTorch 中常见的损失函数与相似度度量的理论基础和实现方法,重点分析了交叉熵、信息熵、负对数似然(NLL)、KL 散度和余弦相似度的数学原理及其在深度学习中的应用场景。通过示例代码详细阐述了这些概念的 PyTorch 实现方式,帮助读者在理论与实践之间建立联系。
Abstract
This article summarizes the theoretical foundations and implementation methods of commonly used loss functions and similarity metrics in PyTorch, focusing on the mathematical principles of cross-entropy, entropy, negative log-likelihood (NLL), KL divergence, and cosine similarity, as well as their applications in deep learning. By presenting detailed example code, the article bridges the gap between theory and practice, helping readers gain a comprehensive understanding.
1 Corss Entropy Loss
1.1 介绍
参数:
- weight:指定权重,(dim),可选参数,可以给每个类指定一个权重。通常在训练数据中不同类别的样本数量差别较大时,可以使用权重来平衡。
- ignore_index:指定忽略一个真实值,(int),也就是手动忽略一个真实值。
- reduction:在[none, mean, sum]中选,string型。
none
表示不降维,返回和target相同形状;mean
表示对一个batch的损失求均值;sum
表示对一个batch的损失求和
输入:
可以看到对于Target
,是有两种情况的。如果Target
是类别的索引,那么Target
的shape是比Input
的shape少了通道维;如果Target
是类别的概率值,那么Target
的shape是与Input
的shape相同。具体可以看官网。
1.2 代码举例
import torch
import torch.nn as nn
import torch.nn.functional as F
# logits shape:[BS, NC]
batch_size = 2
num_class = 7
logits = torch.randn(batch_size, num_class) # input unnormalized score
target_indices = torch.randint(num_class, size=(batch_size,)) # delta目标分布,是整形的索引 shape:(2,)
target_logits = torch.randn(batch_size, num_class) # 非delta目标分布,概率分布 shape:[2,7]
ce_loss_fn = nn.CrossEntropyLoss() # 实例化
## method 1 for CE loss
ce_loss = ce_loss_fn(logits, target_indices)
print(f"cross entropy loss1: {
ce_loss}")
## method 2 for CE loss
ce_loss = ce_loss_fn(logits, torch.softmax(target_logits, dim=-1))
# 将target_logits 进行softmax在通道维上进行归一化成概率分布,和为1
print(f"cross entropy loss2: {
ce_loss}")
输出结果均为标量:
cross entropy loss1: 3.269336700439453
cross entropy loss2: 2.0783615112304688
2 Negative Log Likelihood loss (NLL loss)
2.1 介绍
对于负对数似然Negative Log Likelihood loss(NLL loss)来说,input
是每个类别的log-probabilities,而target
只能为类别索引。实际上,能用cross entropy loss的地方就能用negative log-likelihood loss,这在后面的代码部分进行进一步验证。
2.2 代码举例
nll_fn = nn.NLLLoss()
nll_loss = nll_fn(torch.log(torch.softmax(logits, dim=-1)), target_indices)
print(f"negative log-likelihood loss: {
nll_loss}")
沿用上面初始化的logits
,对其在通道维上进行softmax得到归一化概率值,再取对数,从而获得batch_size个样本的每个class的log-probabilities。target_indices
是类别索引。
输出结果为标量:
negative log-likelihood loss: 3.269336700439453
如果沿用上面CELoss的logits
和target_indices
的初始化值,可以看到cross entropy loss1和negative log-likelihood loss的输出结果是相同的。说明cross entropy loss = s o f t m a x + l o g + n l l l o s s =softmax+log+nll loss =softmax+log+nllloss
3 Kullback-Leibler divergence loss (KL loss)
3.1 介绍
P、Q相当于两个系统,KL距离的定义为:
D K L ( P