Tensor类型的转换
Pytorch数据类型的转换可以通过三个方式:
1)调用Tensor成员函数long(),int(),double(),float(),byte()
2)调用Tensor成员函数type(),传入数据类型
3)调用Tensor成员函数type_as(),传入实例对象
示例文件 test.py
import torch
a = torch.randn(2, 3)
print(a.type())
b = a.int()
c = a.type(torch.LongTensor)
print(b.type())
print(c.type())
a.type_as(b)
print(a.type())
终端命令行及运行结果
<user>python test.py
torch.FloatTensor
torch.IntTensor
torch.LongTensor
torch.FloatTensor
Tensor类型与ndarray类型的转换
1)Numpy转化为Tensor:torch.from_numpy(ndarray)
2)Tensor转化为numpy:Tensor.numpy()
示例文件 test.py
import torch
import numpy as np
torch_a = torch.randn(3, 2)
numpy_a = torch_a.numpy()
print(torch_a)
print(type(numpy_a))
print(numpy_a)
numpy_b = np.array([[1, 2], [3, 4], [5, 6]])
torch_b = torch.from_numpy(numpy_b)
print(type(numpy_b))
print(numpy_b)
print(torch_b)
终端命令行及运行结果
<user>python test.py
tensor([[-0.2467, 0.4057],
[-1.3399, 1.4803],
[-0.6589, -0.0156]])
<class 'numpy.ndarray'>
[[-0.2466862 0.4057195 ]
[-1.3398814 1.4803079 ]
[-0.6588639 -0.01563702]]
<class 'numpy.ndarray'>
[[1 2]
[3 4]
[5 6]]
tensor([[1, 2],
[3, 4],
[5, 6]], dtype=torch.int32)
本文详细介绍了PyTorch中Tensor类型与NumPy数组类型之间的转换方法,包括Tensor内部类型转换如int(), double()等,以及通过type()和type_as()函数进行类型转换。同时,也阐述了如何使用torch.from_numpy()将NumPy数组转换为Tensor,以及如何使用Tensor.numpy()将Tensor转换回NumPy数组。
2623

被折叠的 条评论
为什么被折叠?



