【代码复现】STAEformer


前言

官方Github
论文:STAEformer: Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting

STAEformer复现结果能够对的上论文实验结果,并且配置环境时没有遇到雷点。

一、创建虚拟环境

我创建了一个名称为STAEformer的虚拟环境(名称可以更改),并conda activate进入虚拟环境。

conda create -n STAEformer python==3.9.18
conda activate STAEformer

二、安装cuda、pytorch

官方代码仓要求pytorch>=1.11,这里我在虚拟环境安装cuda11.7和pytorch1.13.1及其相关依赖。

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia

三、Required Packages

依赖包依次安装就完了。

pip install numpy
pip install pandas
pip install matplotlib
pip install pyyaml
pip install torchinfo

四、复现结果

cd model/进入项目model文件夹下,运行

python train.py -d <dataset>

官方很贴心地把数据集都给整理好了,不用自己去找了。
<dataset>:

  • METRLA
  • PEMSBAY
  • PEMS03
  • PEMS04
  • PEMS07
  • PEMS08

复现METRLA为例,测试结果为:

--------- Test ---------
All Steps RMSE = 5.93878, MAE = 2.92491, MAPE = 8.00552
Step 1 RMSE = 3.97519, MAE = 2.26739, MAPE = 5.49113
Step 2 RMSE = 4.65860, MAE = 2.49998, MAPE = 6.28363
Step 3 RMSE = 5.09899, MAE = 2.65260, MAPE = 6.84915
Step 4 RMSE = 5.45347, MAE = 2.76543, MAPE = 7.31178
Step 5 RMSE = 5.72631, MAE = 2.86540, MAPE = 7.71518
Step 6 RMSE = 5.96758, MAE = 2.95223, MAPE = 8.09941
Step 7 RMSE = 6.20411, MAE = 3.02974, MAPE = 8.39245
Step 8 RMSE = 6.37668, MAE = 3.09545, MAPE = 8.68591
Step 9 RMSE = 6.52967, MAE = 3.15964, MAPE = 8.96343
Step 10 RMSE = 6.69013, MAE = 3.21868, MAPE = 9.21720
Step 11 RMSE = 6.82990, MAE = 3.27001, MAPE = 9.42466
Step 12 RMSE = 6.95626, MAE = 3.32242, MAPE = 9.63251
Inference time: 4.86 s
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值