代码如下:
import torch
# Random logits for 3 samples over 4 classes (the class axis is dim 1, as the
# loss functions below expect). requires_grad=True keeps the losses computed
# from them attached to the autograd graph.
logits = torch.randn(3, 4, requires_grad=True)
# One integer class index per row of `logits`. torch.tensor(..., dtype=torch.long)
# replaces the legacy torch.LongTensor constructor: it follows the default
# device (like torch.randn above) instead of always allocating on the CPU.
labels = torch.tensor([1, 0, 2], dtype=torch.long)
print('logits={}, labels={}'.format(logits, labels))
# Cross-entropy loss computed directly in a single fused call.
def calc_ce_loss1(logits, labels):
    """Return the cross-entropy loss of ``logits`` against integer ``labels``.

    Uses the functional form with its default arguments (mean reduction,
    no class weights, ignore_index=-100, no label smoothing), which is what
    a default-constructed ``torch.nn.CrossEntropyLoss`` module applies.
    """
    return torch.nn.functional.cross_entropy(logits, labels)
# Cross-entropy loss decomposed into its two stages:
# cross entropy = log-softmax followed by negative log-likelihood (NLL) loss.
def calc_ce_loss2(logits, labels):
    """Return the cross-entropy loss computed as log-softmax + NLL loss.

    Args:
        logits: unnormalized scores, either batched ``(N, C, ...)`` or a
            single unbatched sample ``(C,)``.
        labels: integer class indices, ``(N, ...)`` for batched input or a
            0-d tensor for an unbatched sample.

    Returns:
        Scalar tensor: NLL of the log-probabilities with the default mean
        reduction, equal to ``torch.nn.CrossEntropyLoss()(logits, labels)``.
    """
    # CrossEntropyLoss uses dim 1 as the class axis for batched input and
    # dim 0 for an unbatched (C,) sample. A hard-coded dim=1 raised an
    # IndexError on 1-D logits, so choose the same axis CE would.
    class_dim = 0 if logits.dim() == 1 else 1
    log_softmax = torch.nn.LogSoftmax(dim=class_dim)
    nll_loss = torch.nn.NLLLoss()
    log_probs = log_softmax(logits)
    return nll_loss(log_probs, labels)
# Baseline: the direct and the decomposed implementations must agree.
loss1 = calc_ce_loss1(logits, labels)
print('loss1={}'.format(loss1))
loss2 = calc_ce_loss2(logits, labels)
print('loss2={}'.format(loss2))

# Add a temperature: the logits are divided by `temperature` before the loss.
# A value below 1 sharpens the softmax and a value above 1 flattens it. Both
# implementations should still report the same loss at every temperature.
for temperature in (0.05, 2):
    logits_t = logits / temperature
    loss1 = calc_ce_loss1(logits_t, labels)
    print('t={}, loss1={}'.format(temperature, loss1))
    loss2 = calc_ce_loss2(logits_t, labels)
    print('t={}, loss2={}'.format(temperature, loss2))
输出如下（logits 由 torch.randn 随机生成且未设置随机种子，每次运行的具体数值都会不同，但同一温度下 loss1 与 loss2 始终相等）：
logits=tensor([[-0.7441, -2.3802, -0.1708, 0.5020],
[ 0.3381, -0.3981, 2.2979, 0.6773],
[-0.5372, -0.4489, -0.0680, 0.4889]], requires_grad=True), labels=tensor([1, 0, 2])
loss1=2.399930000305176
loss2=2.399930000305176
t=0.05, loss1=35.99229431152344
t=0.05, loss2=35.99229431152344
t=2, loss1=1.8117588758468628
t=2, loss2=1.8117588758468628