pytorch doc介绍
简单来说就是增加一个新的维度,在这个新的维度上进行torch.cat()操作。
import torch
a = torch.randn(3,4)
# tensor([[-0.0974, -1.3577, -0.5162, -0.9748],
# [-1.0509, -0.7450, -0.7226, -1.6895],
# [-0.7616, 1.0055, 0.5779, -2.0157]])
b = torch.randn(3,4)
# tensor([[-0.3454, 1.2769, -0.3882, -1.4049],
# [-0.3809, -1.2949, -0.6149, 1.1036],
# [ 0.9674, 1.2621, 1.7883, -0.7552]])
c = torch.cat((a, b), dim=0)
# tensor([[-0.0974, -1.3577, -0.5162, -0.9748],
# [-1.0509, -0.7450, -0.7226, -1.6895],
# [-0.7616, 1.0055, 0.5779, -2.0157],
# [-0.3454, 1.2769, -0.3882, -1.4049],
# [-0.3809, -1.2949, -0.6149, 1.1036],
# [ 0.9674, 1.2621, 1.7883, -0.7552]])
c.shape
# torch.Size([6, 4])
d = torch.stack((a,b), dim=0)
# tensor([[[-0.0974, -1.3577, -0.5162, -0.9748],
# [-1.0509, -0.7450, -0.7226, -1.6895],
# [-0.7616, 1.0055, 0.5779, -2.0157]],
# [[-0.3454, 1.2769, -0.3882, -1.4049],
# [-0.3809, -1.2949, -0.6149, 1.1036],
# [ 0.9674, 1.2621, 1.7883, -0.7552]]])
d.shape
# torch.Size([2, 3, 4])
在学习过程中还发现torch.stack()可以将list列表转换成tensor。
import torch
list = []
for i in range(3):
list.append(torch.randn(1, 2, 2))
# lsit is a list formed by three tensors
# [tensor([[[ 0.9005, -0.1368],
# [-0.1858, -0.8703]]]),
# tensor([[[-1.9890, -0.3895],
# [-0.2434, 0.8056]]]),
# tensor([[[ 0.0505, 0.9198],
# [-1.9719, 2.0393]]])]
list = torch.stack(list)
# list is a tensor of shape(3,1,2,2)
# tensor([[[[ 0.9005, -0.1368],
# [-0.1858, -0.8703]]],
# [[[-1.9890, -0.3895],
# [-0.2434, 0.8056]]],
# [[[ 0.0505, 0.9198],
# [-1.9719, 2.0393]]]])