torch.from_numpy VS torch.Tensor

本文详细解析了PyTorch中torch.Tensor与torch.from_numpy两种转换方式的差异,强调了后者在非float类型转换时的安全性及内存共享特性。
部署运行你感兴趣的模型镜像

torch.from_numpy VS torch.Tensor

最近在造dataset的时候,突然发现,在输入图像转tensor的时候,我可以用torch.Tensor直接强制转型将numpy类转成tensor类,也可以用torch.from_numpy这个方法将numpy类转换成tensor类,那么,torch.Tensortorch.from_numpy这两个到底有什么区别呢?既然torch.Tensor能搞定,那torch.from_numpy留着不就是冗余吗?

答案

有区别,使用torch.from_numpy更加安全,使用tensor.Tensor在非float类型下会与预期不符。

解释

实际上,两者的区别是大大的。打个不完全正确的比方说,torch.Tensor就如同c的inttorch.from_numpy就如同c++的static_cast,我们都知道,如果将int64强制转int32,只要是高位转低位,一定会出现高位被抹去的隐患的,不仅仅可能会丢失精度,甚至会正负对调。这里的torch.Tensortorch.from_numpy也会存在同样的问题。

看看torch.Tensor的文档,里面清楚地说明了,

torch.Tensor is an alias for the default tensor type (torch.FloatTensor).

torch.from_numpy的文档则是说明,

The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.

也即是说,

  1. 当转换的源是float类型,torch.Tensortorch.from_numpy会共享一块内存!且转换后的结果的类型是torch.float32
  2. 当转换的源不是float类型,torch.Tensor得到的是torch.float32,而torch.from_numpy则是与源类型一致!

是不是很神奇,下面是一个简单的例子:

import torch
import numpy as np

s1 = np.arange(10, dtype=np.float32)
s2 = np.arange(10) # 默认的dtype是int64

# 例一
o11 = torch.Tensor(s1)
o12 = torch.from_numpy(s1)
o11.dtype # torch.float32
o12.dtype # torch.float32
# 修改值
o11[0] = 12
o12[0] # tensor(12.)

# 例二
o21 = torch.Tensor(s2)
o22 = torch.from_numpy(s2)
o21.dtype # torch.float32
o22.dtype # torch.int64
# 修改值
o21[0] = 12
o22[0] # tensor(0)

您可能感兴趣的与本文相关的镜像

PyTorch 2.9

PyTorch 2.9

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

torch.tensor torch.from_numpy 都是用于在 PyTorch 中创建张量的函数,但在数据类型转换共享内存方面存在一些区别 [^1]。 ### 区别 - **数据类型转换**:torch.tensor 函数返回的是 torch.Tensor 类型对象,数据类型可通过 dtype 参数指定;而 torch.from_numpy 创建的张量会继承 NumPy 数组的数据类型,若要转换为单精度浮点数类型,需使用 `.float()` 方法 [^3][^5]。 - **内存共享**:torch.tensor 是深拷贝,即创建的张量与原数据在内存中是独立的;torch.from_numpy 是浅拷贝,创建的张量与原 NumPy 数组共享同一块内存,对其中一个的修改会反映到另一个上 [^3]。 ### 使用方法 - **torch.tensor**:可接受多种数据类型,如列表、元组、NumPy 数组等,通过 dtype 参数指定数据类型,还可指定 device 参数。示例代码如下: ```python import torch import numpy as np # 从列表创建张量 tensor_from_list = torch.tensor([1, 2, 3], dtype=torch.float32) print(tensor_from_list) # 从 NumPy 数组创建张量 numpy_array = np.array([4, 5, 6]) tensor_from_numpy_array = torch.tensor(numpy_array, dtype=torch.int64) print(tensor_from_numpy_array) ``` - **torch.from_numpy**:仅接受 NumPy 数组作为输入。若数据是通过 numpy 的加载函数(如 `np.loadtxt()`、`np.load()` 等)从文件(如 `.csv`、`.npy` 文件)中读取的,或者在 numpy 操作过程中生成的,且数据类型是 `numpy.ndarray`,可使用该函数将其转换为 PyTorch 张量 [^4]。示例代码如下: ```python import torch import numpy as np numpy_array = np.array([7, 8, 9]) tensor_from_numpy = torch.from_numpy(numpy_array) print(tensor_from_numpy) ``` ### 相关知识 判断是否适用 torch.from_numpy 的方法有以下几种:可以通过数据来源存储形式,若数据以 numpy 数组形式存在则适用;使用 `type()` 函数检查数据类型,若为 `numpy.ndarray` 适用;查看数据处理流程,若数据一直以 numpy 数组在不同函数或模块间传递,在传给 PyTorch 相关操作前适用 [^4]。将 NumPy 数组转换为 PyTorch 张量后,使用 `.float()` 方法可将张量的数据类型转换为 `torch.float32` [^5]。
评论 11
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值