torch.cat() 和 torch.stack() 函数的作用都是将多个维度参数相同的张量连接成一个张量,不同之处在与 stock()相比于cat()多了一维。这里两个函数都有 dim 这个参数,但是指的意思却不一样。使用下图来解释,在这里将两个张量理解成树这种形式,希望可以帮助理解。
这里竖线代表一个维度,竖线上所有节点代表同一维度的所有元素,在下面所有图中,同颜色的元素都是按照从上往下按顺序排列的。
dim在cat()函数中表示索所要连接的维度,也就是连接 所要连接的多个张量 的这个维度上面的参数。
但是在stack()中,dim表示多出来的维度,这个维度被用来连接之后维度的参数。原来的维度则变成子节点了,例如dim=1,那么 原来张量的第一维度 就变成了 连接之后的张量 的第二维度
假设这里一个torch.randn(2, 3, 4)生成的两个张量,如下图

红和蓝分别表示两个不同的张量,后面所有的图中左边的是使用stack()函数,右边是使用cat()函数,黄色的表示stack()函数生成的多的一维。
那么当 dim = 0时,如下图
dim = 1, 如下图

dim = 2,如下图

对于stack()函数生成的结果会多一个维度,所有在这个例子中会有3这个索引值所代表的第四维度,dim = 3是成立的,但是对于cat()函数则没有这个

本文详细介绍了PyTorch中torch.cat()和torch.stack()函数的使用,它们都用于连接张量,但区别在于torch.stack()会在原始张量的维度之间插入新的一维,而torch.cat()则是在指定维度上直接拼接。通过实例展示了当dim分别为0、1、2时,两个函数如何影响张量的形状,帮助读者深入理解这两个函数的工作原理。
1598

被折叠的 条评论
为什么被折叠?



