时间序列预测的高级深度学习架构
1. 可解释的N - BEATS预测
1.1 N - BEATS工作原理
N - BEATS基于两个主要组件:
- 双重残差连接栈:涉及预测和反向预测。反向预测指重建时间序列的过去值,它能促使模型从两个方向理解时间序列结构,从而学习更好的数据表示。
- 深度密集连接层栈:这两个组件的结合使模型兼具高预测精度和可解释性。
1.2 模型训练、评估和使用流程
模型的训练、评估和使用工作流程遵循PyTorch Lightning提供的框架,具体步骤如下:
1. 数据准备 :数据准备逻辑在数据模块组件中开发,特别是在 setup() 函数内。
2. 建模阶段 :
- 定义N - BEATS模型架构:使用 from_dataset() 方法根据输入数据直接创建NBeats实例。
- 定义训练过程逻辑:在 Trainer 实例中定义训练过程逻辑,包括所需的回调函数。部分回调函数(如提前停止)会将模型的最佳版本保存到本地文件,训练后可加载使用。
代码示例:
from lightning.pytorch.tuner import Tuner
import lightning.pytorch as pl
from pytorch_forecasting import NBeats
# 假设datamodule已定义
超级会员免费看
订阅专栏 解锁全文

被折叠的 条评论
为什么被折叠?



