在pytorch中,常见的拼接函数主要是两个,分别是:
- stack()
- cat()
torch.stack()
函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。
形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。
1 stack()
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。
outputs = torch.stack(inputs, dim=?) → Tensor
参数
inputs : 待连接的张量序列。
注:python的序列数据只有list和tuple。
dim : 新的维度, 必须在0到len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。
2 重点
- 函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等
----举例:[tensor_1, tensor_2,…]或者(tensor_1, tensor_2,…),且必须tensor_1.shape == tensor_2.shape - dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
不懂的看例子,再回过头看就懂了。
3 例子
- 准备2个tensor数据,每个的shape都是[3,3]
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90