【论文翻译】PredFormer:Transformer在时空预测任务中的有效性研究

image-20241025222117681

论文题目 PredFormer: Transformers Are Effective Spatial-Temporal Predictive Learners
论文链接 https://openreview.net/forum?id=avNVrQ8D2v
源码地址 https://github.com/yyyujintang/PredFormer (Coming soon)
关键词 Transformer、时空预测、Gated Transformer Blocks(GTB)

摘要

时空预测学习方法通常分为两类:基于循环的方式,这种方式在并行化和性能上面临挑战;以及无循环的方式,后者采用卷积神经网络(CNN)作为编码-解码架构。这些方法受益于强大的归纳偏置,但通常以可扩展性和泛化能力为代价。本文提出了PredFormer,一种纯基于Transformer的时空预测学习框架。受视觉Transformer(ViT)设计的启发,PredFormer利用精心设计的门控Transformer模块,经过对包括全注意力、因子分解以及交错时空注意力在内的3D注意力机制的全面分析。通过无循环的Transformer设计,PredFormer不仅简单且高效,在性能上大幅超越了以往的方法。对合成数据集和真实世界数据集的大量实验表明,PredFormer实现了最新的性能表现。在Moving MNIST数据集上,PredFormer相较于SimVP减少了51.3%的MSE。在TaxiBJ数据集上,该模型将MSE降低了33.1%,并将FPS从533提升到了2364。此外,在WeatherBench数据集上,PredFormer将MSE降低了11.1%,同时将FPS从196提高到了404。这些在精度和效率上的提升展示了PredFormer在实际应用中的潜力。源码和预训练模型将向公众开放。

1 介绍

时空预测学习通过基于过去的观测来预测未来的帧,从而学习空间和时间模式。这种能力在天气预报、交通流预测、降水预报和人体运动预测等应用中至关重要。

尽管各种时空预测学习方法取得了成功,但它们通常难以在计算成本和性能之间取得平衡。一方面,基于循环的强大方法依赖于自回归的RNN框架,然而这些方法在并行化和计算效率方面面临显著限制。另一方面,基于CNN的无循环方法虽然提高了效率,但由于局部感受野的限制,扩展性和泛化能力有限。这引发了一个更为基础的问题:我们是否可以开发一个框架,自主学习时空依赖关系,而不依赖归纳偏置?

image-20241025224403851

图2:时空预测学习框架的主要类别。(a) 基于循环的框架 (b) 基于卷积神经网络编码-解码的无循环框架 ( c) 基于纯Transformer的无循环框架。

一种直观的解决方案是直接采用纯Transformer结构,因为它是RNN的高效替代方案,且比CNN更具可扩展性。Transformer在视觉任务中已表现出显著的成功。尽管已有的方法试图将Swin Transformer集成到RNN框架中,或将MetaFormer作为时间翻译器集成到无循环的CNN编码器-解码器框架中,但纯Transformer架构仍然主要处于探索阶段,特别是在捕捉统一框架中的时空关系方面存在挑战。尽管将空间和时间维度合并并应用全注意力的概念在理论上是可行的,但由于注意力与序列长度成平方的扩展,使得这种方法在计算上非常昂贵。为了减少复杂性,最近的几种方法通过因子分解或交错方式分别处理空间和时间关系。

在这项工作中,我们提出了PredFormer,一种纯基于Transformer的时空预测学习架构。PredFormer深入分解空间和时间Transformer,通过与门控线性单元(GLU)的自注意力结合,更有效地捕捉复杂的时空动态。除了保留空间优先和时间优先配置的全注意力编码器和因子分解编码器策略外,我们还引入了六种新颖的交错时空Transformer架构,共产生九种配置。这种探索是为了应对不同任务和数据集的不同空间和时间分辨率及依赖关系。通过全面的调查,推动了当前模型的边界,并为时空建模设定了有价值的基准。

特别地,PredFormer在三个基准数据集上取得了最先进的性能,包括合成的移动物体预测、交通流预测和天气预报。在不依赖复杂模型架构或专用损失函数的情况下,PredFormer以较大幅度超越了以往的方法。此外,我们的最优模型在性能上表现出色,参数更少,FLOP更低,推理速度更快,展示了其在实际应用中的巨大潜力。

主要贡献如下:

  • 我们提出了PredFormer,一种纯基于门控Transformer的时空预测学习模型。通过消除CNN中固有的归纳偏置,PredFormer利用了Transformer的可扩展性和泛化能力,使其成为一个高度可适应的模型,显著提高了潜力和性能上限。
  • 我们对时空Transformer因子分解进行了深入分析,探索了全注意力编码器和因子分解编码器,以及交错时空Transformer架构,共得出了九种PredFormer变体。这些变体针对不同任务和数据集的空间和时间分辨率,优化了性能。
  • 据我们所知,PredFormer是第一个用于时空预测学习的纯Transformer模型。我们对从头开始在小数据集上训练ViT进行了全面研究,探索了正则化和位置编码技术。
  • 大量实验表明,PredFormer表现出卓越的性能。与SimVP相比,PredFormer在Moving MNIST上将MSE降低了51.3%,在TaxiBJ上降低了33.1%,同时将FPS从533提高到2364,在WeatherBench上将MSE降低了11.1%,并将FPS从196提高到404。

2 相关工作

image-20241025224237703

图1:(a) PredRNN、SimVP和PredFormer的性能表现;(b) 模型效率对比。图中位置越靠内的模型表示准确率和效率越高。

基于循环的时空预测学习

基于循环的时空预测模型的最新进展整合了CNN、ViT和Vision Mamba等结构到RNN中,采用多种策略来捕捉时空关系。ConvLSTM通过将卷积操作集成到LSTM框架中创新性地提出。PredNet利用自底向上的连接和自顶向下的连接来预测未来的视频帧。PredRNN引入了时空LSTM单元(ST-LSTM),通过传播隐藏状态有效地捕捉并记忆空间和时间表示。PredRNN++则通过引入梯度高速公路单元和因果LSTM来解决梯度消失问题,并自适应地捕捉时间依赖性。E3D-LSTM扩展了ST-LSTM的记忆能力,通过集成3D卷积。MIM模型进一步优化了ST-LSTM,重新设计了遗忘门,使用双循环单元并在隐藏状态之间利用差异信息。CrevNet使用基于CNN的可逆架构来有效地解码复杂的时空模式。PredRNNv2通过引入记忆解耦损失和课程学习策略来增强PredRNN。MAU设计了专门用于捕捉动态运动信息的运动感知单元。SwinLSTM则将Swin Transformer模块集成到LSTM架构中,而VMRNN扩展了该方法。

无循环的时空预测学习

最新的无循环模型,如SimVP,基于CNN编码器-解码器设计。TAU通过将时间注意力分离为静态帧内和动态帧间成分,并引入差异散度损失来监督帧间变化。OpenSTL集成了MetaFormer模型作为时间翻译器。此外,PhyDNet将物理原理引入CNN架构,而DMVFN则引入了动态多尺度体素流网络来增强视频预测性能。WAST提出了一种基于小波的损失函数。与之前的方法相比,PredFormer在其无循环的纯Transformer架构中通过利用全局感受野,提升了时空学习,超越了之前的模型表现。

视觉Transformer (ViT)

ViT在各种视觉任务中展示了出色的性能。在视频处理领域,TimeSformer研究了空间和时间自注意力的因子分解,并提出分别应用时间和空间注意力的方案以获得最佳准确率。ViViT探讨了因子分解编码器、自注意力和点积机制,得出的结论是优先应用空间注意力的因子分解编码器表现更好。TSViT发现优先应用时间注意力的因子分解编码器能够取得更好的结果。尽管取得了这些进展,目前大多数现有模型主要集中在视频分类领域,较少有研究将ViT应用于时空预测学习。PredFormer通过将自注意力与门控线性单元相结合,进一步深入分解时空Transformer,能够更强大地捕捉复杂的时空动态。

3 方法

为了系统地分析网络模型在时空预测学习中的Transformer结构,我们提出了PredFormer作为通用模型设计。

image-20241025230319810

图3:(a) PredFormer模型框架概述。(b) 从空间视角和时间视角的序列分解。( c) 门控Transformer模块。(d) 门控线性单元。

3.1 纯基于Transformer的架构

Patch Embedding:按照ViT的设计,PredFormer将帧序列 X X X 切分为大小为 p p p 的非重叠patch,生成序列 N = ⌊ H p ⌋ × ⌊ W p ⌋ N = \left\lfloor \frac{H}{p} \right\rfloor \times \left\lfloor \frac{W}{p} \right\rfloor N=pH×pW ,每个patch被扁平化为一维token。这些token被线性投影到隐藏维度 D D D ,并通过层归一化(LN)处理,生成张量 X ′ ∈ R B × T × N × D X' \in \mathbb{R}^{B \times T \times N \times D} XRB×T×N×D

Position Encoding:不同于典型的ViT方法,我们引入了二维时空位置编码(PE),该编码通过正弦函数生成并为每个patch分配绝对坐标。

PredFormer Encoder:这些一维token随后通过PredFormer编码器进行特征提取。PredFormer编码器由门控Transformer块以不同方式堆叠而成。

Patch Recovery:由于我们的编码器基于纯门控Transformer,不涉及卷积或分辨率减少,全球上下文在每一层建模。这允许其与简单的解码器配对,形成强大的预测模型。解码器将线性层作为解码器,将隐藏维度投影回去以恢复二维patch。

3.2 门控Transformer块

标准Transformer模型在多头注意力(MSA)和前馈网络(FFN)之间交替。每个头的注意力机制定义为:

Attention ( Q , K , V ) = Softmax ( Q K ⊤ d k ) V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V Attention(Q,K,V)=Softmax(dk QK

评论 7
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

holdoulu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值