LSTNet--结合时间注意力机制的LSTM模型(附源码)

本文介绍了LSTNet模型,它是通过融合注意力机制提升LSTM在时间序列预测能力的一种尝试。LSTNet包含LST-Skip和LST-Atten两种模型,其中LST-Atten能自动捕捉序列周期。文章详细阐述了LST-Atten的工作原理,并提供了模型的PyTorch实现。在公共交通流量数据集上的实验表明,LST-Atten在周期序列预测中表现出良好性能,但对无明显周期性的序列效果有限。

一、引言

        LSTM出现以来,在捕获时间序列依赖关系方面表现出了强大的潜力,直到Transformer的大杀四方。但是,就像我在上一篇博客《RNN与LSTM原理浅析》末尾提到的一样,虽然Transformer在目标检测、目标识别、时间序列预测等各领域都有着优于传统模型的表现,甚至是压倒性的优势。但Transformer所依赖的Multi-Head Attention机制给模型带来了巨大的参数量与计算开销,这使得模型难以满足实时性要求高的任务需求。我也提到,LSTM想与Transformer抗衡,似乎应该从注意力机制方面下手。事实上,已经有研究这么做了,那就是LSTNet。

二、LSTNet

        2018年,论文《Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks》正式提出了LSTNet。LSTNet的出现可以认为是研究人员通过注意力机制提升LSTM模型时序预测能力的一次尝试,文中共提出了LST-Skip与LST-Atten两种模型。其中,LST-Skip需要手动设置序列的周期,比较适用于交通流预测等周期明确可知的时间序列,而LST-Atten模型则可以自动捕捉模型的周期。实验表明,上述两种模型在周期序列预测中表现出了良好的性能。

        然而,上述模型的性能受制于序列的周期与可用历史状态的长度。首先,模型的注意力机制为“时间注意力机制”,其本质是利用了序列内部的时间性周期,因此对于没有明显周期性的序列(如:车辆轨迹序列)则不能很好地发挥优势。其次,LST-Atten依赖于历史状态挖掘序列的周期,若可用的历史状态较短,无法反映一个完整的周期,则模型可自主挖掘周期性的优势仍无法体现。

三、方法        

        本文以实现LST-Atten为例(在进入FC层前的张量处理与原文稍有不同),描述LSTM中的时间注意力机制。由于在优快云的编辑器中不方便使用各种专业符号,因此下文中使用的符号一切从简,不以专业性为目的。

        我们假设使用过去的A帧数据预测未来的B帧,且LSTM编码器中A个LSTM单元的隐状态为H,LSTM解码器中第一个LSTM单元的隐状态为h1。那么集成了时间注意力机制的LSTM编码-解码器工作原理可用如下表示:

         其中,F为打分函数,用于计算H与h1之间的余弦相关度。然后,通过softmax函数,这些余弦相关度被转换为各历史隐状态的相对权重。H中各时刻隐藏状态的加权和与解码器第一帧输出的隐藏状态h1相

评论 79
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值