在pytorch的官方文档中,KLDivLoss的input和target的shape要保持一致,即如果真实标签(target) 采用如’0、1、2、3…'这种数字标签编码时,要将数字标签编码转为ont-hot编码,即
[0, 1, 2, 3,....]===>[[0, 0, 0, 0,....],
[0, 1, 0, 0,....],
[0, 0, 1, 0,....],
................]]
然后将input和target送入KLDivLoss进行计算。
但是计算CrossEntropyLoss时,target和input的shape可以是不同的,Pytorch官方的代码中对于target支持两种输入,一种是类标签输入,一种是类的概率输入。因此可以target可以以数字标签的形式输入,也可以以one-hot的形式输入。

2262

被折叠的 条评论
为什么被折叠?



