Problem:
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 ‘mat2’
Line that raises the exception:
prediction = net.forward(b_x)
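Solution:
The two operands of the matrix multiply have different dtypes. The input b_x is float64 (Double), typically because it was built from a NumPy array, while the network's weights ('mat2' here) are float32 (Float). Add a line before the forward call that converts the input to dtype=torch.float32: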
b_x = torch.tensor(b_x, dtype=torch.float32)
prediction = net.forward(b_x)
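A minimal sketch that reproduces the mismatch and the fix, assuming a single nn.Linear layer stands in for net and random NumPy data stands in for b_x (both are illustrative, not the original code). It calls net(b_x) instead of net.forward(b_x), because calling the module also runs any registered hooks.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Linear creates float32 (Float) weights by default.
net = nn.Linear(3, 1)

# NumPy arrays default to float64, and torch.from_numpy keeps that dtype,
# so b_x is a float64 (Double) tensor.
b_x = torch.from_numpy(np.random.rand(4, 3))
print(b_x.dtype, net.weight.dtype)  # torch.float64 torch.float32

mismatch_raised = False
try:
    net(b_x)  # float64 input against float32 weights
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)  # exact wording depends on the PyTorch version

# Fix: cast the input to match the parameters. For an existing tensor,
# .float() (same as .to(torch.float32)) avoids the warning torch.tensor()
# emits when it is handed a tensor instead of a list or NumPy array.
b_x = b_x.float()
prediction = net(b_x)
print(prediction.dtype, prediction.shape)  # torch.float32 torch.Size([4, 1])
```

If b_x is still a NumPy array or a list at that point, torch.tensor(b_x, dtype=torch.float32) as written above works fine; .float() is simply the idiomatic choice once it is already a tensor.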
References
https://github.com/pytorch/pytorch/issues/2138
https://pytorch.org/docs/stable/tensors.html?highlight=float#torch.Tensor.float
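Alternatively, cast explicitly when converting from NumPy. Note that .double() produces float64, so the model must then be cast to float64 as well (for example with net.double()):
# if you use the DataLoader you can avoid any miscommunication regarding built-in data structures for ATen
X = torch.from_numpy(X).double()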
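When the data goes through a DataLoader, the default collate function keeps each sample's dtype, so float64 NumPy data still arrives as float64 unless it is cast somewhere. The sketch below casts once when building the dataset, so every batch is already float32. The random X/y arrays and the use of TensorDataset are assumptions for illustration only. It also shows the opposite route taken by the .double() line: keep float64 data and cast the model with .double().

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Illustrative float64 data, as NumPy produces by default.
X = np.random.rand(10, 3)
y = np.random.rand(10, 1)

# Cast once here so every batch the DataLoader yields is float32.
dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).float())
loader = DataLoader(dataset, batch_size=4)

net = nn.Linear(3, 1)  # float32 parameters
batch_dtypes = set()
for b_x, b_y in loader:
    prediction = net(b_x)
    batch_dtypes.add(b_x.dtype)

# Opposite route, matching torch.from_numpy(X).double(): keep float64 data
# and cast the model's parameters to float64 too.
net64 = nn.Linear(3, 1).double()
out64 = net64(torch.from_numpy(X).double())
print(batch_dtypes, out64.dtype)  # {torch.float32} torch.float64
```

Float32 is the usual choice: it is PyTorch's default parameter dtype and is faster and lighter on most GPUs, while float64 is mainly worth it when the extra precision is actually needed.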