1. 出错代码行
计算交叉熵是出现异常提示:RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1549635019666/work/aten/src/THNN/generic/ClassNLLCriterion.c:21
loss = criterion(prediction, target)
2. 原因:
CrossEntropyLoss does not expect a one-hot encoded vector as the target, but class indices
pytorch 中计计算交叉熵损失函数时, 输入的正确 label (target)不能是 one-hot 格式。所以只需要输入数字 4 就行,不需要输入 one hot 格式的 [ 0 0 0 0 1]。函数内部会自己处理成 one hot 格式。
loss = criterion(prediction, target)
print(prediction.size()