Tensor相关的不同维度变换操作
维度变换操作
Tensor中涉及维度的操作主要包括:
- 重排维度:
transpose
,swapaxes
,permute
,这些操作不会改变维度的个数,只会改变各个维度的顺序,例如(dim0,dim1,dim2,dim3) 变成 (dim2,dim3,dim0,dim1)。 - 重组维度:
view
、reshape
,这些操作,是会改变维度个数的,因为可以对维度进行重新排列,shape中不仅维度个数可能会改变,每一个维度的长度也会改变。
重排维度
使用transpose
在PyTorch中,假设我们有一个形状为 2,3,42,3,4 的张量,我们想要交换最后两个维度。
import torch
x = torch.randn(2, 3, 4)
x_transposed = x.transpose(1, 2) # 交换维度1和维度2
print(x_transposed.shape) # 输出: torch.Size([2, 4, 3])
使用permute
如果我们想要重新排列维度,例如将形状为 2,3,42,3,4 的张量重新排列为 4,2,34,2,3。
x_permuted = x.permute(2, 0, 1) # 重新排列维度
print(x_permuted.shape) # 输出: torch.Size([4, 2, 3])
重组维度
张量的维度变换常见的方法有torch.view()
和torch.reshape()
,两个的区别:
torch.view()
是直接在原tensor上进行变换,返回的tensor和原tensor共用一块地址;
torch.reshape()
返回的新tensor不是和原tensor共用一块地址。但是此函数并不能保证返回的是其拷贝值,所以官方不推荐使用。推荐的方法是我们先用 clone()
创造一个张量副本然后再使用 torch.view()
进行函数维度变换 。
注:使用
clone()
还有一个好处是会被记录在计算图中,即梯度回传到副本时也会传到源 Tensor 。
view方法就是指定一个shape来改变tensor的形状;同时他还可以自动推断维度大小;
import torch
# 创建一个1x4的Tensor
x = torch.randn(1, 4)
print(f"Original Tensor:\n{
x}\n")
# 使用view方法改变形状为2x2
y = x.view(2, 2)
print(f"Reshaped Tensor (2x2):\n{
y}"