Pytorch中nn.Conv2d数据计算模拟
最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块。在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin, W, H)
其中B为batch_size表示batch的大小,Cin为输入数据的特征大小(通道数),W、H对于图像数据来说分别表示图像数据的宽和高。输出数据为(B, Cout, W', H')
其中Cout表示输出的特征大小,W’, H’取决于W, H,具体转换方式如下图所示:
通过查询nn.Conv2d的源码可知,nn.Conv2d底层是由nn.functional.Conv2d实现的,所以可以使用nn.functional.Conv2d模拟nn.Conv2d操作。
# nn.Conv2d源码
def conv2d_forward(self, input, weight):
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,