- 先实现交叉熵并验证
import torch
import torch.nn as nn
import torch.nn.functional as F
# Random logits in (batch, channel, height, width) layout: shape (2, 3, 2, 2),
# double precision, created as an autograd leaf so gradients can be inspected.
logits1 = torch.rand(2, 3, 2, 2, dtype=torch.float64, requires_grad=True)
# Integer class labels in (batch, height, width) layout; every value is in
# {0, 1, 2}, matching the 3 channels of the logits.
label1 = torch.tensor(
    [
        [[1, 2], [1, 1]],
        [[2, 0], [2, 0]],
    ]
)
# Second set of logits holding the same values as logits1.
# NOTE(review): clone() is a differentiable op, so logits2 is a non-leaf tensor
# that is still attached to logits1's autograd graph — confirm this is intended.
logits2 = torch.clone(logits1)
label2 = torch.tensor([1,2,