拼接
维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。 选择一个dim进行拼接的时候其他两个维度大小要相等
对于三维张量,理解 torch.cat
的 dim
参数确实变得更加抽象,但原理是相同的。让我们通过一个具体的例子来说明这一点。
import torch
# 创建两个 3D 张量
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
print("Tensor a shape:", a.shape)
print(a)
print("\nTensor b shape:", b.shape)
print(b)
# dim=0 连接
c_dim0 = torch.cat([a, b], dim=0)
print("\nResult of torch.cat([a, b], dim=0):")
print("Shape:", c_dim0.shape)
print(c_dim0)
# dim=1 连接
c_dim1 = torch.cat([a, b], dim=1