最近在阅读swin transformer的代码,在其PatchEmbedding模块发现了一个很少见的操作,Linear的输入竟然是四维(见下图1,swin transformer解读原文见:CV+Transformer之Swin Transformer - 知乎 (zhihu.com)),和以往常见的二维输入完全不一样,因此自己写代码测试了一下:

得出的结论如下:
pytorch中,Linear类的这种全连接层只会对输入Tensor的最后一维进行计算,论据如下图:

在此处我生成了一个(1,3,224,224)的四维tensor,将其输入一个为维度为(224,512)的全连接层计算后得到了一个维度为(1,3,224,512)的向量。这个用法几乎很少见,因此做个记录。
博主在阅读SwinTransformer代码时发现Linear层被用于四维输入,这并不常见。通过实验验证,发现在PyTorch中Linear层只对输入Tensor的最后维度进行计算。输入(1,3,224,224)经过(224,512)的全连接层后得到(1,3,224,512)的输出。这一观察对于理解Transformer模型的实现具有参考价值。
1195

被折叠的 条评论
为什么被折叠?



