1、关于torch.stack做拼接的时候,无论是按照那个维度进行拼接,其结果都会进行扩维的,下面来看具体的例子:
import torch
a=torch.tensor([[1,2,3],[4,5,6]])
b=torch.tensor([[7,8,9],[10,11,12]])
c=torch.stack((a,b), 2)
print(c)
d=torch.stack((a,b),1)
print(d)
e=torch.stack((a,b),0)
print(e)
结果如下:
tensor([[[ 1, 7],
[ 2, 8],
[ 3, 9]],
[[ 4, 10],
[ 5, 11],
[ 6, 12]]])
tensor([[[ 1, 2, 3],
[ 7, 8, 9]],
[[ 4, 5, 6],
[10, 11, 12]]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
2、当使用cat进行拼接的时候,具体的例子如下:
import torch
a=torch.tensor([[1,2,3],[4,5,6]])
b=torch.tensor([[7,8,9],[10,11,12]])
c=torch.cat((a,b),0)
print(c)
d=torch.cat((a,b),1)
print(d)
结果如下:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
总结:
从上面的两个例子中我们是可以看出这两个API的不同,前者是无论你是以哪一个维度来进行拼接,都会改变当前的维度,但是后者不会改变当前的维度。以上两者通常会在深度学习中使用。