N-BEATS:独特的可解释时间序列预测深度学习模型
介绍
在金融、零售和气象等各个领域,时间序列预测至关重要。传统模型通常在灵活性和可解释性方面存在问题,尤其是在处理复杂模式时。N-BEATS 是一种突破性的模型,由 Element AI 和蒙特利尔学习算法研究所 (MILA) 的研究人员开发。N-BEATS 以其可解释性和强大的预测能力而闻名,是真正的游戏规则改变者。
什么是 N-BEATS?
N-BEATS(时间序列预测的神经基础扩展分析)彻底改变了时间序列预测的方法。与依赖于循环神经网络 (RNN) 的典型模型不同,N-BEATS 采用一系列前馈神经网络。这种结构不仅提高了性能,而且还避免了 RNN 经常出现的复杂性和不稳定性。
N-BEATS 的主要功能
可解释性:N-BEATS 的一大亮点是其可解释性,它提供了哪些数据组件影响预测的见解——这是深度学习模型中罕见的功能。
模块化:该架构包含多个块,可以以各种方式配置以适应不同的应用程序,从而无需改变核心框架即可进行广泛的定制。
泛化:N-BEATS 旨在处理多种时间序列数据,可以适应不同的数据集而无需进行特定调整。
深入探究架构
块结构:N-BEATS 的核心由多个块组成,每个块负责捕获特定数据模式,例如趋势或季节性。这些块预测和预测回测值,帮助模型关注时间序列的不同方面。
堆叠块:N-BEATS 的优势在于堆叠这些块。每个块层通过处理残差(实际值与先前预测之间的差异)来细化预测。这种细化提高了最终预测的准确性。
N-BEATS 的工作原理
训练:N-BEATS 通过交替预测未来值和重建过去值(回溯)进行训练。它最大限度地减少了预测和实际数据之间的误差,从而提高了预测能力。
预测:在预测中,N-BEATS 汇总所有块的预测,确保在最终预测中考虑到从总体趋势到特定季节性模式的每个数据方面。
例如:每日温度数据
为了演示 N-BEATS,我们使用每日温度数据集,这是时间序列建模中的常见基准。它是一个周期为一年(365 天)且振幅为 10 的正弦波。此示例将展示如何使用 PyTorch 实现 N-BEATS 并解释模型的输出以了解不同时间序列组件的影响。
让我们探索练习的主要部分:(完整代码的链接位于参考部分)
步骤 1:创建合成温度数据集的函数
def create_temperature_dataset ( length, num_samples ):
np.random.seed( 0 )
x = np.linspace( 0 , length, num_samples)
seasonal = 10 + 10 * np.sin( 2 * np.pi * x / 365 )
noise = np.random.normal( 0 , 2 , num_samples)
y = seasonal + noise
return x, y
创建数据集
x, y = create_temperature_dataset( 365 , 3650 )