Python,Pytorch构建双输入网络,具有两个输入端,对序列进行二分类。
本程序是两边输入均是一维序列,有基础也可改为一边图像一边序列,或两边都是图像。
如何基于Pytorch构建双输入的网络是曾经长时间困扰本人的问题,现已弄明白。
整体工作如下:
1、加载数据集,inputdata1.csv里是400x500的数据,即样本数400x序列长度500。inputdata2.csv里是400x250的数据,即样本数400x序列长度250。inputdata1和inputdata2中样本一一对应,每行是一个样本,同属一个标签。在这400个样本中,前200个为0类,后200个为1类。
2、划分训练集测试集。双输入数据处理中,这部分比较麻烦些,需要将dataset弄成为(input_data1, input_data2, label)格式,然后再用troch的random_split函数随机划分训练集(80%)和测试集(20%)。
3、构建双输入网络,示例起见,两边均为简单的Conv网络。