Pytorch定义网络结构的时候出错:TypeError: new() received an invalid combination of arguments - got (float, int), but expected one of:
问题分析
首先,我的错误复现:
>>> import torch.nn as nn
>>> fc1 = nn.Linear(512*3, 512/2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/zhangboshen/anaconda3/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 48, in __init__
self.weight = Parameter(torch.Tensor(out_features, in_features))
TypeError: new() received an invalid combination of arguments - got (float, int), but expected one of:
* (torch.device device)
* (torch.Storage storage)
* (Tensor other)
* (tuple of ints size, torch.device device)
didn't match because some of the arguments have invalid types: (float, int)
* (object data, torch.device device)
didn't match because some of the arguments have invalid types: (float, int)
其实上面这一段代码,放到python2.7里面定义是完全没有问题的,问题是我现有的python3的环境对于除法的规则和python2不一样。
上面的512/2,在python3中的结果是256.0,但是在python2中就是256,python2和3的这些区别,真的让人处处感到绝望。。。
解决方法,把/换成//,后者的运算能够保证是int型数据。
说到python2和3的坑,还有一个容易出现的问题是字符串,python3中读出来之后默认是byte类型,这会导致你读出来的字符串前面多出一个b,比如b'015601864.jpg'这样的,接下来,就会导致opencv读不到图片然后引起一系列连锁反应随后你会怀疑自己写的代码是如此的垃圾然而并不是的,是python2和3的差异导致的。。。。
解决方法是在 Python3 中,bytes 和 str 互相转换:
str.encode('utf-8')
bytes.decode('utf-8')

本文解析了在PyTorch中定义网络结构时遇到的TypeError问题,源于Python3与Python2除法规则差异。介绍了如何通过使用整数除法(//)解决此问题,并提及了Python3中字符串读取的常见误区及解决方案。
4956





