一、摘要
本文介绍谷歌团队发表于2023年的论文《A decoder-only foundation model for time-series forecasting》
译文:
受自然语言处理(NLP)中大型语言模型最新进展的启发,我们设计了一种用于预测的时间序列基础模型,其在各种公共数据集上的开箱即用零样本性能接近于每个单独数据集的最先进监督预测模型的准确性。我们的模型基于预训练一个带有输入分片的解码器风格注意力模型,使用一个包含真实世界和合成数据集的大型时间序列语料库。在一组多样化的先前未见过的预测数据集上的实验表明,该模型可以在不同领域、预测范围和时间粒度上产生准确的零样本预测。
二、核心创新点
作者指出,时间序列预测的基础模型应当能够适应可变的上下文和预测范围长度,同时具有足够的容量来编码来自大型预训练数据集的所有模式。因此,作者采用了经过实践验证的Transformer架构作为基础,并加入了几个特定于时间序列的设计选择:
- 分片(patching):在训练期间,作者将时间序列分解为一个个的patch。由于输入到Transformer中的Token数量被patch长度的因子减小了,使得推理速度得到了提升。
- 仅解码器模型(Decoder-only):论文中的模型以仅解码器的模式进行训练。给定一系列的输入patches,经过优化之后,模型可以根据所有过去的patch来预测下一个patch。
- 更长的输出patch:作者允许用于预测的输出patch比输入的patch更长。例如,假设当前输入patch长度为32,输出的patch长度可以是128。
- patch掩码:作者认为,如果只是直接使用patch进行训练,模型可能只会针对输入patch长度的倍数的上下文长度进行预测。因此在每个data batch中,可以对上下文窗口开头的部分patch以及整个patch进行掩码Mask,且训练期间采用特定的随机掩码策略,这将有助于模型看到从1开始到最大上下文长度的所有可能上下文长度。
1、输入层
输入层的作用是将时间序列预处理为输入Token,以供Transformer使用。首先,将输入分割为连续的且不重叠的patch片段,然后每个patch有一个残差块处理成大小为model_dim的向量。与输入一致,作者还提供了一个二进制的padding掩码,其中1表示
中对应的应被忽略的输入。换句话说,输入
被分割为大小为input_path_len(p)的patch片段集合,而第 j 个patch可以表示为
,同样地掩码也可以分片为
,由此第 j 个输入的Token可以表示为:
其中,表示第 j 个位置编码。
2、堆叠Transformer层
堆叠Transformer层有着标准的多头自注意力机制,其后是全连接层FFN。主要的超参数是模型的维度(model_dim),它等于输入Token(设为)的维度以及头的数量(num_heads),作者将FFN的隐藏层大小也设置为与模型维度相同。此外,作者使用了因果注意力机制,即每个输出Token只能关注序列中在它之前的输入Token,这可以用方程表示:
对于所有的,
是第 j 个Token的掩码指示符,定义为
,即如果有一个patch有任何non-masked的时间点,则相应的Token被标记为未Masked,所有完全被掩码的patch不会被因果自注意力机制关注。
3、输出层
TimesFM以仅解码器的模式进行训练,即每个输出Token都应该能够预测其对应的最后一个输入patch之后的时间序列部分。然而,时间序列基础模型与主流大模型的区别在于,输入patch的长度不必等于输出patch的长度,也就是说我们应该能够基于到目前为止看到的输入patch的编码信息来预测更大块的时间序列,设输出patch的长度为output_patch_len(h),作者使用另一个残差块将输出Token映射到预测上:
因此,作者将的所有数据编码为
,并使用它来预测随后的h时间点
。这在一个训练的mini-batch中对所有的patch都适用。
4、损失函数
论文指出,TimesFM专注于点的预测(point forecasting),因此训练期间使用基于点的损失,例如均方误差MSE,训练期间最小化的损失可以表示为:
5、训练
TimesFM训练的方式遍历时间序列的所有窗口以及跨时间序列。唯一非标准的部分是作者在训练期间采样掩码的方式。对于一个batch中的每个时间序列,作者在0到p-1之间随机采样一个数字 r,然后令,其余设为0。也就是说这个方式mask了第一个输入patch的一部分,然而这足以覆盖从1到最大训练上下文长度的所有输入上下文长度。
举个例子,假设最大上下文长度为512,且p是32。如果r=4,那么在看到第一个patch(来自)之后,输出预测被优化为在看到28=32-4个时间点后进行预测,下一个patch(来自
)的输出被优化为在看到28+32个时间点后进行预测,以此类推。当对所有这样的 r 重复这个做法时,模型已经看到了直到512的所有可能的上下文。