1. 基本语法介绍
torch.reshape() 是 PyTorch 中用于改变张量形状的函数。它返回一个新张量,其数据与输入张量相同,但具有指定的形状。如果可能,该函数将返回输入张量的视图(即不复制数据),但如果无法返回视图,则会返回数据的副本。
改变形状时,元素不能变
2. 代码解释
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("原始张量 x: ")
print(x)
print("x的形状:", x.shape)
输出:
原始张量 x:
tensor([[1, 2, 3],
[4, 5, 6]])
x的形状: torch.Size([2, 3])
y = torch.reshape(x, (3, 2))
print("形状改变后的张量 y:")
print(y)
print("y的形状:", y.shape)
输出:
形状改变后的张量 y:
tensor([[1, 2],
[3, 4],
[5, 6]])
y的形状: torch.Size([3, 2])
z = torch.reshape(x, (1, 6))
print("形状改变后的张量 z:")
print(z)
print("z的形状:", z.shape)
输出:
形状改变后的张量 z:
tensor([[1, 2, 3, 4, 5, 6]])
z的形状: torch.Size([1, 6])
# reshape((-1