基于PyTorch的改良Transformer模型用于多维时间序列分类

Gated-Transformer-on-MTS: 基于PyTorch的改良Transformer模型用于多维时间序列分类

项目概述

在这里插入图片描述

该项目提出了一种改良的Transformer架构(Gated Transformer)用于处理多维时间序列(MTS)分类任务。通过双塔结构和门控机制,模型能够同时捕捉时间步(step-wise)和通道(channel-wise)之间的关系,并在多个基准数据集上取得了优于传统CNN基线(FCN和ResNet)的性能。

核心创新点

  1. 双塔结构:同时计算step-wise和channel-wise的注意力机制

    • Step-wise塔:保留传统Transformer的位置编码和mask机制
    • Channel-wise塔:去除位置编码和mask,专注于通道间关系
  2. 门控机制:自适应融合双塔输出

    h = W · Concat(C, S) + b \\
    g1, g2 = Softmax(h) \\
    y = Concat(C · g1, S · g2)
    

在这里插入图片描述

  1. 仅使用Encoder:针对分类任务简化模型结构

实验结果

在13个多元时间序列数据集上的分类准确率对比:

数据集FCNResNetGated Transformer
ArabicDigits99.499.698.8
AUSLAN97.597.497.5
CharacterTrajectories99.099.097.0
CMUsubject1610099.7100
ECG87.286.791.0
JapaneseVowels99.399.298.7
Libras96.495.488.9
UWave93.492.691.0
KickvsPunch54.051.090.0
NetFlow89.162.7100
PEMS--93.6
Wafer98.298.999.1
WalkvsRun100100100

技术实现细节

数据预处理

  • 处理不等长时间序列:使用零填充至最大时间步长
  • 创建PyTorch Dataset和DataLoader对象
  • 特殊处理NetFlow数据集的标签(1和13→0和1)
    在这里插入图片描述

模型架构

  • 输入处理:线性层将原始输入映射到d_model维
  • 双塔注意力
    • Step-wise塔:带位置编码和mask的多头注意力
    • Channel-wise塔:无位置编码的多头注意力
  • 门控融合:学习自适应权重融合双塔特征
  • 分类头:全连接层输出分类结果
    在这里插入图片描述

超参数配置

{
    "d_model": 512,          # 模型维度
    "d_hidden": 2048,        # FFN隐藏层维度
    "q": 64,                 # Query维度
    "v": 64,                 # Value维度
    "h": 8,                  # 注意力头数
    "N": 6,                  # Encoder层数
    "dropout": 0.1,          # Dropout率
    "EPOCH": 100,            # 训练轮数
    "BATCH_SIZE": 32,        # 批大小
    "LR": 1e-4,              # 学习率
    "optimizer": "Adam"      # 优化器
}

项目结构

Gated-Transformer-on-MTS/
├── dataset_process.py       # 数据集处理
├── module/                  # 模型模块
├── utils/                   # 工具类
│   ├── random_seed.py       # 随机种子设置
│   ├── heatMap.py           # 热力图绘制
│   ├── visualization.py     # 训练曲线可视化
│   └── TSNE.py              # 降维聚类可视化
├── run.py                   # 训练脚本
├── run_with_saved_model.py  # 测试脚本
├── saved_model/             # 模型保存目录
└── result_figure/           # 结果图目录

使用说明

  1. 环境配置

    • Python 3.7
    • PyTorch ≥1.6
    • 支持CPU/GPU
  2. 数据集准备

    • 从百度云下载.mat格式数据集
    • 路径: https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A (提取码: dxq6)
  3. 训练模型

    python run.py --dataset ECG --d_model 512 --h 8 --N 6
    
  4. 测试模型

    python run_with_saved_model.py --dataset ECG --model_path saved_model/ECG_best_model.pkl
    

可视化工具

项目提供了多种可视化工具:

  • 注意力权重热力图(对比DTW和欧氏距离)
  • 训练过程中的loss/accuracy曲线
  • t-SNE降维聚类图
  • 自定义折线图绘制

注意事项

  1. 模型保存使用PyTorch 1.6+格式,加载时需兼容版本
  2. 数据集路径和结果保存路径不建议修改
  3. 可视化工具中的颜色映射可能需要根据具体数据集调整

该项目为多维时间序列分类任务提供了一种有效的Transformer改良方案,通过创新的双塔结构和门控机制,在多个数据集上展现了优越性能。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值