Pyroch中nn.Sequential的多输入问题
nn.Sequential为何不能处理多输入?
先上两个小例子直观感受一下:
class MyLayer(nn.Module):
def __init__(self):
super(MyLayer, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 1, 1, bias=False)
self.conv2 = nn.Conv2d(3, 32, 1, 1, bias=False)
def forward(self, x1, x2):
y1 = self.conv1(x1)
y2 = self.conv2(x2)
return y1+y2
x = torch.rand(1, 3, 32, 32)
model = MyLayer()
y = model(x, x)
print(y.shape)
上面这段代码是可以正常运行的,即一般我们使用nn.Module建立model,多输入是支持的。但是,当我们使用nn.Sequential封装我们的MyLayer时,多输入就不会被支持。继续上面的例子:
model = nn.Sequential(MyLayer())
x = torch.rand(1, 3, 32, 32)
y = model(x, x)
print(y.shape)
运行这段代码时就会出现如下错误:
y = model(x, x)
File “E:\anacon