在训练模型过程中有如下语句,结果出现报错
output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
RuntimeError: gather_out_cuda(): Expected dtype int64 for index
原因应该是和pytorch版本有关系,根据报错,找到原因为target.unsqueeze(2)的dtype是int32类型的,把它转为int64了,可以正常运行。
在使用PyTorch训练模型时,遇到了一个RuntimeError,问题出在gather操作上,因为targetunsqueeze(2)的dtype是int32而不是expected的int64。将target的dtype转换为int64后,报错消除,程序得以正常运行。
在训练模型过程中有如下语句,结果出现报错
output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
RuntimeError: gather_out_cuda(): Expected dtype int64 for index
原因应该是和pytorch版本有关系,根据报错,找到原因为target.unsqueeze(2)的dtype是int32类型的,把它转为int64了,可以正常运行。
您可能感兴趣的与本文相关的镜像
PyTorch 2.6
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
397