import numpy as np
import torch
a = np.arange(2 * 3 * 20 * 20).reshape(2, 3, 20, 20)
array1 = torch.tensor(a)
tensor_list = []
np_list = []
for i in range(4):
tensor_list.append(array1)
np_list.append(a)
dim0 = torch.stack(tensor_list, dim=0)
dim1 = torch.stack(tensor_list, dim=1)
dim2 = torch.stack(tensor_list, dim=2)
dim3 = torch.stack(tensor_list, dim=3)
dim4 = torch.stack(tensor_list, dim=4)
np_list = torch.tensor(np_list)
print(np_list.size())
print(dim0.size())
print(dim1.size())
print(dim2.size())
print(dim3.size())
print(dim4.size())
输出
# 变化前
torch.Size([4, 2, 3, 20, 20])
# 变化后
torch.Size([4, 2, 3, 20, 20])
torch.Size([2, 4, 3, 20, 20])
torch.Size([2, 3, 4, 20, 20])
torch.Size([2, 3, 20, 4, 20])
torch.Size([2, 3, 20, 20, 4])
相当于将dim指定的那个维度后面的都不动,然后将dim=0换到指定的dim上,剩下的dim=1到指定的dim向前移动一位。
假设size是【4,2,3,20,20】
令函数的dim=3,那么输出的size是【2,3,20,4,20】