类似于numpy中的reshape,且不改变原始数据。重点是,它是从最后一个维度进行reshape中的
import torch
a = torch.randint(0, 10, size=(3, 4))
print(a)
print(a.view(4, 3))
print(a.view(1, 12))
分别得到
tensor([[7, 1, 7, 2],
[9, 5, 7, 1],
[6, 0, 9, 9]])
tensor([[7, 1, 7],
[2, 9, 5],
[7, 1, 6],
[0, 9, 9]])
tensor([[7, 1, 7, 2, 9, 5, 7, 1, 6, 0, 9, 9]])