loss_func = nn.CrossEntropyLoss()
这个交叉熵损失函数输入的两个参数的shape并不是相同的,一个是各个类别分别的概率(类似于独热码),另一个是位置下标的阿拉伯数字
另外:
例如:
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(