ERROR
使用pytorch的函数 torch.nn.CrossEntropyLoss()计算Loss时报错
或者
loss = criterion(output, target)
报错:
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed
解决方法:
原因一:模型输出与分类数不一致
看模型的输出尺寸与分类数差异是否明显,核查代码是否存在错误。
如果没有错误,只是映射维度不对,可以考虑在模型的最后一层加一层FC层,将输出尺寸