torch reshape函数理解
torch.reshape
在变形时会遵循以下步骤:
-
将张量的所有元素视为一维连续数据:无论原始张量的维度如何,
torch.reshape
都会将其看成一个一维的线性数据流。这并不实际改变张量的存储顺序(即内存布局),而是改变它的形状视图。 -
按照新的形状重新组织元素:根据指定的新形状,重新组织这些元素。要求新形状的元素数量总数必须等于原张量的元素总数(即
torch.numel
的值保持不变)。
例如:
import torch
# 创建一个 2x3 的张量 (2,3)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将x看成一维数据流为 1,2,3,4,5,6. 一维数据流的公式化解释,以形状(2,3)为例:
# 也就是(0,0)(0,1)(0,2)(1,0)(1,1)(1,2)。有点类似于嵌套for循环的意思,先固定
# 0维不变,将1维从小变到大,再将0维+1后固定,将1维从小变到大。
# 使用 reshape 将其变成 3x2 的张量
y = torch.reshape(x, (3, 2))
# 将(0,0)(0,1)(0,2)(1,0)(1,1)(1,2)变为形状(3,2),将每两个划为一行,即:
# (0,0)(0,1)|(0,2)(1,0)|(1,1)(1,2)
print(y)
输出:
tensor([[1, 2],
[3, 4],
[5, 6]])
注意点:
-
内存连续性:
torch.reshape
通常会尝试返回与原张量共享内存的视图(view)。但如果无法共享内存,例如张量的存储不是连续的,reshape
可能会创建一个新张量。 -
元素总数不变:
torch.reshape
要求新形状的元素数量与原张量一致,否则会报错。例如:z = torch.reshape(x, (4, 2)) # 报错:原张量的元素数是 6,无法变成 8
如果你只是希望将张量展平成一维,可以使用 torch.flatten
或 torch.reshape(x, (-1,))
。
与 view
的区别:
torch.view
和 torch.reshape
功能类似,但 view
要求输入张量在内存中是连续的,而 reshape
可以处理非连续张量。