class Partial_conv3(nn.Module):
    """Partial 3x3 convolution followed by a ``Conv`` projection.

    A 3x3 convolution (stride 1, padding 1, no bias) is applied to the first
    ``dim // int(n_div)`` channels of the input; the remaining channels are
    passed through unchanged. All ``dim`` channels are then fed to
    ``Conv(dim, ouc, k=1)``.

    Args:
        dim: Number of input channels.
        ouc: Output-channel argument forwarded to ``Conv``.
        n_div: Channel divisor; ``dim // int(n_div)`` channels are convolved.
            Non-integer values (e.g. ``4.0``) are truncated with ``int``.
        forward: Which forward implementation to bind, ``'slicing'`` or
            ``'split_cat'``.

    Raises:
        NotImplementedError: If ``forward`` is not one of the supported modes.
    """

    def __init__(self, dim, ouc, n_div=4, forward='split_cat'):
        super().__init__()

        # Resolve the forward implementation first so that an unsupported
        # mode fails before any layer is allocated.
        if forward == 'slicing':
            forward_impl = self.forward_slicing
        elif forward == 'split_cat':
            forward_impl = self.forward_split_cat
        else:
            raise NotImplementedError(
                "unsupported forward mode {!r}; expected 'slicing' or "
                "'split_cat'".format(forward)
            )

        self.dim_conv3 = dim // int(n_div)  # make sure n_div is an integer
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
        # NOTE(review): `Conv` is defined outside this block; presumably the
        # project's conv wrapper, used here with kernel size 1 -- confirm.
        self.conv = Conv(dim, ouc, k=1)

        # Bound as an instance attribute so calling the module dispatches to
        # the selected implementation.
        self.forward = forward_impl

    def forward_slicing(self, x):
        """Convolve the first ``dim_conv3`` channels by slice assignment.

        Works on a clone of ``x`` so the caller's tensor is not modified.
        ``x`` is indexed with four axes, channels on axis 1.
        """
        # only for inference
        x = x.clone()  # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        x = self.conv(x)
        return x

    def forward_split_cat(self, x):
        """Convolve the first ``dim_conv3`` channels via split and concatenate.

        ``x`` is split along axis 1 into ``dim_conv3`` and ``dim_untouched``
        channels; only the first part is convolved before re-joining.
        """
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        x = self.conv(x)
        return x
网站上找的版本不好使,我上传一个自己能用的版本,仅供参考。