Python,Pytorch使用一维CNN-Transformer,对一维序列进行分类源程序。
程序旨在学习如何构建CNN-Transformer网络,以及如何转换数据维度使得CNN的输出能够衔接Transformer,本程序是将CNN的输出通道数直接匹配作为Transformer的维度。CNN可以提取空间特征,Transformer则提取长时序列特征,模型创新性较强。
在使用此程序时,建议先大致了解Transformer框架的基本结构:Transformer模型中有Encoder和Decoder模块。参考了许多使用Transformer做分类的程序,模型中均是只使用了Encoder模块。本程序仅使用了Transformer的Encoder模块,没有用Decoder。且没有用Embedding,因为考虑到需要级联CNN,Embedding不好写进去,而且序列也没有明确的位置信息,可以不用Embedding。
程序工作如下:
1、加载数据。原始数据为Excel,400条1*500的序列(心电信号),其中200条正常,200条异常。
2、构建CNN-Transformer模型。其中,CNN用了2层,Transformer_Encoder用了6层,里面nhead=4。
3、训练、测试。显示训练集准确率和Loss变化,计算测试集Acc、P