pytorch中torch.stack()函数
一、基本功能
将若干个张量在dim维度上连接,生成一个扩维的张量,比如说原来你有若干个2维张量,连接可以得到一个3维的张量。
设待连接张量维度为n,dim取值范围为-n-1~n,这里得提一下为负的意义:-i为倒数第i个维度。举个例子,对于2维的待连接张量,-1维即3维,-2维即2维。
a=torch.tensor([[1,2,3],[4,5,6]])
b=torch.tensor([[10,20,30],[40,50,60]])
c=torch.tensor([[100,200,300],[400,500,600]])
print(torch.stack([a,b,c],dim=0))
print(torch.stack([a,b,c],dim=1))
print(torch.stack([a,b,c],dim=2))
print(torch.stack([a,b,c],dim=0).size())
print(torch.stack([a,b,c],dim=1).size())
print(torch.stack([a,b,c],dim=2).size())
#输出结果为:
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 10, 20, 30],
[ 40, 50, 60]],
[[100, 200, 300],
[400, 500, 600]]])
tensor([[[ 1, 2, 3],
[ 10, 20, 30],
[100, 200, 300]],
[[ 4, 5, 6],
[ 40, 50, 60],
[400, 500, 600]]])
tensor([[[ 1, 10, 100],
[ 2, 20, 200],
[ 3, 30, 300]],
[[ 4, 40, 400],
[ 5, 50, 500],
[ 6, 60, 600]]])
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])
torch.Size([2, 3, 3])
二、规律分析
通过代码运行结果,我们不难发现,stack(tensors,dim=0,out=None)函数的运行机制可以等价为:
- dim=0时,将tensor在一维上连接,简单来说就是,就是将tensor1,tensor2…tensor n,连接为【tensor1,tensor2… tensor n】(就是在这里产生了扩维)
- dim=1时,将每个tensor的第i行按行连接组成一个新的2维tensor,再将这些新tensor按照dim=0的方式连接
- dim=2时,将每个tensor的第i行转置后按列连接组成一个新的2维tensor,再将这些新tesnor按照dim=0的方式连接
本文介绍了PyTorch中的torch.stack()函数,该函数用于在指定维度上连接多个张量,创建一个三维张量。根据dim参数的不同,连接方式有所不同:dim=0时在第一维度连接,dim=1时按行连接,dim=2时按列连接。通过实例展示了不同dim值下的输出结果及其尺寸变化。
1214

被折叠的 条评论
为什么被折叠?



