最近在阅读swin transformer的代码,在其PatchEmbedding模块发现了一个很少见的操作,Linear的输入竟然是四维(见下图1,swin transformer解读原文见:CV+Transformer之Swin Transformer - 知乎 (zhihu.com)),和以往常见的二维输入完全不一样,因此自己写代码测试了一下:
得出的结论如下:
pytorch中,Linear类的这种全连接层只会对输入Tensor的最后一维进行计算,论据如下图:
在此处我生成了一个(1,3,224,224)的四维tensor,将其输入一个为维度为(224,512)的全连接层计算后得到了一个维度为(1,3,224,512)的向量。这个用法几乎很少见,因此做个记录。