【ICLR 2023】时间序列预测实战Crossformer(附代码+数据集+详细讲解)

本文对Crossformer模型进行实战讲解,它是用于多变量时间序列预测的深度学习模型。介绍了其网络结构,包括DSW嵌入、TSA层、HED结构,还提及数据集、参数、训练、预测等内容,最后说明训练个人数据集的修改方法。

论文地址:官方论文地址

代码地址:官方代码地址


 一、本文介绍

本篇文章给大家带来的实战讲解是Crossformer模型,其是一个针对多变量时间序列预测的新型深度学习模型,发表ICLR 2023上并且排名前5%,所以这个模型的质量还是能够有一定保证的(但是我用官方的代码真的是Bug一堆改的让人头大)。Bug多是很多但是其效果还是可圈可点的,Crossformer的主要思想是:通过维度-段式嵌入技术将时间序列数据转换为二维向量数组,同时使用两阶段注意力层来高效地捕获这两种依赖关系。Crossformer采用分层编码器-解码器结构,在不同层次上利用信息进行预测。

 专栏目录:时间序列预测目录:深度学习、机器学习、融合模型、创新模型实战案例

专栏订阅: 时间序列预测专栏:基础知识+数据分析+机器学习+深度学习+Transformer+创新模型

目录

 一、本文介绍

二、网络结构讲解

2.1 Crossformer的主要思想

2.2、维度-段式(DSW)嵌入

2.3、两阶段注意力(TSA)层

2.4、分层编码器-解码器(HED)结构

2.5、模型代码

三、数据集

四、参数讲解 

五、模型训练

六、配置代码 

七、模型预测

八、训练个人数据集 

8.1、修改一

8.2、修改二

九、全文总结 


二、网络结构讲解

2.1 Crossformer的主要思想

这个模型是一种新的基于Transformer的模型,名为Crossformer,这个模型在今年的ICLR上提出,是一种专门为多变量时间序列(MTS)预测设计。 Crossformer的主要特点包括:

1. 维度-段式(DSW)嵌入:这种新颖的嵌入技术将多变量时间序列数据沿每个维度划分为段,将这些段嵌入到特征向量中。这种方法保持了时间和维度信息,有助于模型更好地捕捉MTS数据的固有结构。

2. 两阶段注意力(TSA)层:Crossformer使用TSA层来有效捕捉时间和不同维度之间的依赖性。对于MTS预测来说,这两个方面的依赖性都是重要的。

3. 分层编码器-解码器(HED)结构:模型使用HED来利用不同规模的信息进行预测。这种分层方法有助于更有效地理解和预测MTS数据。 论文表明,通过其独特的方法,Crossformer有效地捕捉了跨维度依赖性,这是现有基于Transformer的MTS预测模型中常常忽视的一个关键方面。通过在六个真实世界数据集上的广泛实验结果显示,Crossformer在性能上超越了以前的最先进模型,表明了其有效性和实际应用的潜力。 这项研究通过解决现有模型的局限性并引入创新技术以提高性能,对时间序列预测领域做出了重要贡献。

下面我分别来解释这个模型中的三种结构->


2.2、维度-段式(DSW)嵌入

DSW嵌入是Crossformer模型的一个关键特性,它的目的是更好地捕捉MTS(多变量时间序列)数据中的跨维度依赖关系。传统的基于Transformer的模型主要关注于捕捉时间跨度上的依赖(即跨时间依赖),而往往没有显式地捕捉不同变量间的依赖性(即跨维度依赖),这限制了它们的预测能力。

在DSW嵌入中,每个维度的时间序列数据点被分成一定长度的段。然后,每个段被嵌入到一个向量中,方法是使用线性投影加上位置嵌入。线性投影矩阵E和位置嵌入Epos都是可学习的。这样,每个嵌入后的向量hid表示一个时间序列的单变量段,最终得到一个二维向量数组H。在这个数组中,每个向量hid代表一段时间序列的一维切片。与其他针对MTS预测的Transformer模型不同,DSW嵌入显式地捕捉了跨维度依赖性​。

上图展示了Crossformer模型中维度-段式(DSW)嵌入的概念:

a) 由在ETH1数据集上训练的双层Transformer模型得出的自注意力分数热图,展示了多变量时间序列(MTS)数据倾向于被分段。

b) 描述了之前Transformer基础模型的嵌入方法,这些模型将同一时间步的不同维度的数据点嵌入到一个向量中。

c) 展示了Crossformer的DSW嵌入:在每个维度中,相邻的时间点形成一个段进行嵌入。

总结:这个图解清晰地说明了Crossformer如何通过其DSW嵌入机制来处理MTS数据,这是该模型与以前的方法不同的地方,它保留了时间序列数据在不同时间步的维度信息,以便更好地捕捉跨维度的依赖性。


2.3、两阶段注意力(TSA)层

两阶段注意力(TSA)层是Crossformer模型的一个核心组成部分,用于捕捉嵌入数组中的跨时间和跨维度依赖性。具体来说,通过DSW嵌入,输入数据被嵌入到一个二维向量数组中,以保留时间和维度的信息。然后,TSA层被设计出来,用于捕捉这些嵌入数组的依赖性。

以下是TSA层的工作流程说明:

1. 跨时间阶段:

  • TSA层接收一个二维数组 Z 作为输入,这个数组可能是维度-段式(DSW)嵌入的输出或下层TSA层的输出。
  • 对于每个维度,直接应用多头自注意力(MSA)机制来捕捉同一维度内不同时间段之间的依赖关系。
  • 这一阶段的计算涉及到层归一化(LayerNorm)和多层感知机(MLP),这有助于处理自注意力机制的输出。
  • 此阶段的计算复杂度为 O(DL^2) ,其中  是段的数量,D 是维度的数量。

2. 跨维度阶段:

  • 为了避免直接在维度之间应用MSA所带来的O(D^2)的计算复杂度,提出了一种路由机制。
  • 为每个时间步设置了一小组可学习的向量,称为“路由器”,用于从所有维度聚集信息。
  • 这些路由器随后将聚合的信息分发到各个维度,有效地建立了维度之间的全连接,而没有高复杂度。
  • 路由机制显著降低了复杂度,从O(D^2L)减少到O(DL),通过限制需要考虑的连接数量。
  • 与跨时间阶段类似,跨维度阶段也使用层归一化和MLP来处理路由机制的输出

这种两阶段的方法使Crossformer能够高效地处理多变量时间序列数据中的复杂依赖关系,通过区别对待时间轴和维度轴,尊重它们在数据结构中的独特作用。

上图显示了两阶段注意力(TSA)层的构造和功能:

a)TSA层的整体结构,包含了跨时间阶段(Cross-Time Stage)和跨维度阶段(Cross-Dimension Stage),用于处理 O(2cD)=O(D)


2.4、分层编码器-解码器(HED)结构

分层编码器-解码器(HED)结构在Crossformer模型中用于多变量时间序列(MTS)预测,并能捕获不同尺度上的信息。HED结构包括编码器和解码器两个部分,它们按照以下步骤工作:

1. 编码器:

  • 除了第一层之外,编码器的每一层都会将时间域内两个相邻的向量合并,以获得更粗糙级别的表示。
  • 然后应用TSA层来捕获这个尺度上的依赖性。
  • 如果层数不是2的倍数,将进行填充以确保适当的长度。
  • 这个过程的输出表示为Zencl,它是编码器第 l 层的输出。
  • 编码器的每一层的复杂度是O(DT^2/L_{seg}^2)

2. 解码器:

  • 解码器接收编码器输出的N+1个特征数组,并使用N+1层(索引为0到N)进行预测。
  • 第 l 层取第 l 层编码的数组作为输入,然后输出解码的二维数组。
  • 解码过程中也使用了TSA层和多头自注意力机制(MSA),构建编码器和解码器之间的连接。
  • 解码器的每一层的复杂度是O(D\tau(T+\tau)/L_{seg}^2)

3. 最终预测:

  • 应用线性投影到解码器的每一层输出,以产生该层的预测。
  • 然后将所有层的预测相加,以得到最终的预测结果。

HED结构能够利用不同层次的信息进行预测,通过合并相邻的向量,并在不同的尺度上捕获依赖关系,最终通过解码器产生预测结果(其实看着好像考虑挺多但是结果我认为也就那样,对于时间序列领域我觉得往往简单才是真谛纯属个人见解哈哈)

上图展示了Crossformer模型中分层编码器-解码器(HED)结构的架构,其中包含3层编码器层次。每个向量的长度表示它所覆盖的时间范围。编码器(左侧)利用TSA层和段合并来捕捉不同尺度上的依赖关系:上层的一个向量覆盖了更长的时间范围,从而在更粗糙的尺度上产生依赖性。解码器(右侧)通过在每个尺度上进行预测并将它们相加来制作最终的预测。


2.5、模型代码

大家可以根据上面流程来缕一缕下面的代码应该会有一定的收获。

class Crossformer(nn.Module):
    def __init__(self, data_dim, in_len, out_len, seg_len, win_size = 4,
                factor=10, 
评论 86
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Snu77

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值