最近在使用新版 Pytorch 0.4.0(听说1.0版本马上就要出了…)训练 GAN 的时候遇到了这样一个BUG,RuntimeError: Expected object of type torch.DoubleTensor but found type torch.cuda.FloatTensor for argument #2 ‘weight’ 。相信很多初学者都会在刚开始的时候遇到这样的问题,我就 debug 了一天的时间。首先我们看看问题本身,报错信息提示我们某个变量的类型错了,应该是 torch.cuda.FloatTensor 但是我们给的是 torch.DoubleTensor ,这个地方很容易理解反 这里说的是第二个参数 weight 的数据类型是 torch.cuda.FloatTensor,但应该是 torch.DoubleTensor,大家在前期学习深度学习基础理论的时候应该知道我们传入的数据是要和每层网络的 filter 进行卷积运算的,卷积核上的参数就是 weight,它的数据类型是你实例化你建立的模型的时候决定的(见下面的代码)。那么,这里的正确类型也就是 torch.DoubleTensor 是哪里来的呢?其实这个就是我们输入数据的数据类型,所以这里实际上是我们输入数据的类型错了而不是模型需要的类型错了,这也是这个报错令人疑惑的地方,它告诉我们一个错误但其实它是另一个错误引起的,确实有点反直觉。下面我们结合代码看看为什么会触发这个BUG。
以下是有关这个错误的部分代码:
main.py
cudnn.benchmark = True
device = torch.device('cuda:3')
G = Generator().to(device)
D = Discrimina