输入的图像通道数可以在网络中拆分,也可以在网络中合并,这样就能实现特征融合的多样性。
def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
x.chunk(2, dim=1)
是使x在1的维度上均等的拆分为2份。