最近几天,由于要复现一篇论文,但是论文中的公式太过复杂,层层嵌套,本人表示无法求出导数来,遂决定借助框架的力量来进行操作,尝试过tensorflow,但是其静态图的模式调试起来过于麻烦,不利于随时随地取出数据来验证,于是开始转攻pytorch,用来一段时间,感觉真的很好用,但是在编写过程中也遇到很多的坑,于是决定将这些坑分享出来,以待后来人。
pytorch和tensorflow一样的一点在于他们接受的能够处理的数据类型是张量,将我们在python中常用的list,array等数据类型转换为张量是使用框架至关重要的一部,如果你连合格的数据都提供不了,就不用说接下来的学习了,下面贴出python常用类型与张量tensor之间的互换函数:
Tensor---->Numpy 可以使用 data.numpy(),data为Tensor变量
Numpy ----> Tensor 可以使用torch.from_numpy(data),data为numpy变量
如果遇到类型为list的数据可以先转换为array类型的再用torch.from_numpy转换为tensor类型
# torch.
long
() 将tensor投射为
long
类型
long_tensor = tensor.
long
()print(long_tensor)
# torch.half()将tensor投射为半精度浮点类型
half_tensor = tensor.half()print(half_tensor)
# torch.int()将该tensor投射为int类型
int_tensor = tensor.
int
()print(int_tensor)
# torch.double()将该tensor投射为double类型
double_tensor = tensor.
double
()print(double_tensor)
# torch.float()将该tensor投射为float类型
float_tensor = tensor.
float
()print(float_tensor)
# torch.char()将该tensor投射为char类型
char_tensor = tensor.
char
()print(char_tensor)
# torch.byte()将该tensor投射为byte类型
byte_tensor = tensor.
byte
()print(byte_tensor)
# torch.short()将该tensor投射为short类型
short_tensor = tensor.
short
()print(short_tensor)