目录
序言:
张量是什么,相信大家经过上一节的介绍,大家都有所了解了,接下来我们开始讲解张量的拼接操作,相信各位小伙伴在第一次接触张量的拼接操作的时候,小脑袋都是懵懵的,别怕,有我在,接下来我会给你详细讲解一下张量的拼接操作
进行张量拼接操作的函数torch.cat()
首先我们来看一下官网对这个函数是怎样描述的
知道各位小伙伴不喜欢看英语,我翻译一下
将所给的张量序列按照指定的维度进行连接,所有的张量除了被连接的维度之外必须具有相同的形状或者为空.
torch.cat()可以视为torch.split()和torch.chunk的一种反向操作
torch.cat()可以通过例子很好的理解
参数:
张量(张量序列):任意具有相同的python数据类型的张量序列,当所提供的张量不为空时,除了要连接的维度,其余维度必须有相同的形状
dim(整数):要连接的维度
关键字参数:
输出(张量):输出的张量
上面的torch.split()与torch.chunk()函数有感兴趣的小伙伴可以去了解一下,在此不做赘述。
看完官网对torch.cat()函数的描述之后,我相信大家虽然对这一个函数大致有了个了解,但是心中仍然会有疑惑,别急,我们接着往下讲
维度(轴):
我们先看一下下面这一张图片
对于这一个矩阵我们会说他是一个二维的,此时这个维指的是横向与纵向两个轴,维度指的是轴的数量
那我们再看下一个例子,向量
我们会说这是一个4维的向量,此时维度的意义表示其沿轴0的长度,
为了防止大家对此产生疑惑,我们使用轴这个概念来代替维度,特别地,矩阵含有两个轴,向量具有一个轴
上述torch.cat()函数所沿维度进行拼接其实是沿轴进行拼接。读到这里你可能会疑惑,虽然知道了沿轴进行拼接,但是哪一个轴是轴0哪一个是轴1呢?
下面我们来看一下具体是怎样对其进行拼接的,以及怎么知道哪一个轴是轴0,哪一个轴是轴1?
实例:
首先我们创建两个张量x与y
# 导入包
import torch
# 创建张量x与张量y
x = torch.tensor([[[3, 5, 4, 1], [4, 1, 2, 8], [9, 7, 5, 0]]])
y = torch.zeros((3, 3, 4))# 分别打印张量x与y和他们的形状
print(x.shape)
print(x)
print(y.shape)
print(y)
下面是创建出的两个张量
可以看出x的形状是(1,3, 4),y的形状是(3, 3, 4)
提前给你剧透一下哪一个是轴0,哪一个是轴1,如图所示、
对于轴数更多的张量,仍然是这样的,
你可能会说,你说的就对?我不信。 先别急我给你演示一下
下面我们将其沿轴0进行拼接
# 将x与y张量沿轴0进行拼接
z = torch.cat((x, y), dim=0)
print(z)
它拼接之后的内容就应该是图中标出的部分进行拼接,我们来看看结果
确实就是沿着轴0将内容进行拼接了
接下来我们试着将x与y沿着轴1进行拼接,你们猜一下会发生什么?
答案是无法进行拼接,原因如图
如果你们还是不理解我们可以将三个轴都画出来,演示一下拼接操作
沿轴0的拼接是进行下面的操作
我们可以看到是可以进行拼接的,拼接后的张量正是上面我们演示的按轴0拼接后的张量
这样是不是沿轴1无法拼接就明白了,因为y中蓝色与黄色的在x中没有对应的拼接位置,如果能进行拼接,拼接后的张量是没有固定的规模的,也就是说没法说torch.Size,举个例子将x与沿轴0进行拼接,拼接后的张量,我们可以说形状是(4, 3, 4),但是沿轴1拼接后我们没法这样说。
同理沿轴2进行拼接会出现与沿轴1拼接一样的问题,也就是说在不沿轴0进行拼接的时候,因为x与y沿轴0的长度不同导致无法进行拼接。同理如果不是轴0的长度不同,而是沿轴1的长度不同,在沿轴0或轴2进行拼接的时候仍然会出现上述问题。
总结
这就是为什么,torch.cat()要求除了要连接的轴之外其他的轴必须具有相同的形状。
我是apprentice_eye,一个致力于让知识变的易懂的博主
小伙伴们,点个关注再走吧!!!