pytorch模型输入两个参数

最近用Pytorch需要用到两个输入,发现有很多朋友也没有找到解决方法,今天来分享一下我的做法。
**

forward:

**

    def forward(self, x1, x2):
        feature = self.conv1(x1)
        out = torch.cat([feature, x2], dim=1)

        out1 = self.layer1(out)
        out2 = self.layer2(out1)

        out = torch.cat([out1, out2], dim=1)
        out3 = self.layer3(out)

        out = torch.cat([out1, out2, out3], dim=1)
        out = self.layer4(out)

        out = out + x1

        return out

只需要在原来forward(self,x)再加一个参数即可
**

get_data中可以用下列方法分成input1和input2

**

def split_input(input, band=12):
    # 输入为四维数据,第一维是batch,第二维才是band
    input1 = input[:, band:band + 1]
    # 注意这里要使用band:band+1才不会改变input的维度
    input2 = np.delete(input, band, axis=1)
    return input1, input2

这里的input是dataset返回的Inputs
**

torchsummary的测试

**

summary(model.cuda(), [[1, 30, 30], [24, 30, 30]], batch_size=16)

**

tensorboard的测试

**

    dummy_input1 = torch.randn(16, 1, 30, 30)
    dummy_input2 = torch.randn(16, 24, 30, 30)

    with SummaryWriter(comment='RDNet') as w:
        w.add_graph(model,(dummy_input1,dummy_input2))

以上测试均能成功。
注意目前tensorboardx还不支持pytorch1.4.0要回退到1.3.0,graph才能正常生成。

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值