在使用pytorch深度学习框架,计算分类损失时经常会遇到这么一个函数nn.CrossEntropyLoss(),该损失函数结合了nn.LogSoftmax()和nn.NLLLoss()两个函数。它在做分类(具体几类)训练的时候是非常有用的,如下我将对该函数的原理使用代码和实例进行分析。
首先输入是size是(minibatch,C)。这里的C是类别数。损失函数的计算如下:

损失函数中也有权重weight参数设置,若设置权重,则公式为:

注意这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象才为实际类别。
举个栗子,我们一共有5种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:
举个栗子,我们一共有5个类别,batch_size =2,
import torch
import torch.nn as nn
import math
entroy=nn.CrossEntropyLoss()
input=torch.tensor([[ 0.0043, -0.0174, -0.0153,

博客围绕使用Pytorch深度学习框架计算分类损失展开,介绍了结合nn.LogSoftmax()和nn.NLLLoss()的损失函数,说明了其输入格式、权重参数设置,还通过具体代码和实例分析了该函数的原理及损失计算过程。
最低0.47元/天 解锁文章
1861





