TimesFM
论文
A decoder-only foundation model for time-series forecasting
模型结构
TimesFM是一种基于区块的decoder-only模型,应用了自注意力机制和传统的位置编码,主要由三个组件组成:输入层、Transformer层和输出层。
1、输入层:将时间序列数据分割成相等长度的时序数据块(patch),然后通过残差块对每个时序数据块进行线性变化,进而得到Token。
2、Transformer层:应用了位置编码和自注意力机制。位置编码将时间信息注入Token(令牌)序列;自注意力允许模型学习序列中不同标记之间的依赖关系和关系;位置编码介入自注意力的构造意味着模型可以适应数据中不同的时间粒度和频率。
3、输出层:使用层归一化和残差连接,将输出Token映射到最终预测。
算法原理
TimesFM是google research推出的一种时序预测基础模型,在真实世界的大型语料库上进行了预训练。TimesFM能够适应不同的上下文和预测长度,并且与最新的LLM相比体量更小(200M参数),同时在未见过的数据集上也能zero-shot预测。具体的,TimesFM对时间序列进行分块和位置编码注入,再通过堆叠的Transformer层提炼出数据中的时间顺序信息和不同时间点的关系。
环境配置
注意:本仓库中的tensorflow只应用于长期基准测试读取数据集,运行源码发生了OOM错误,目前认为是tf-gpu与jax-gpu的内存分配冲突导致:GPU memory allocation — JAX documentation。解决方法:
1、安装tf-gpu,但在timesfm/experiments/long_horizon_benchmarks/data_loader.py中添加:
tf.config.experimental.set_visible_devices([], "GPU")
2、安装tf-cpu
-v 路径、docker_name和imageID根据实际情况修改
Docker(方法一)
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04-py310
docker run -it --network=host --privileged=true --name=timesfm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ <imageID> /bin/bash # <imageID>为以上拉取的docker的镜像ID替换
cd /your_code_path/timesfm
pip install praxis==1.2.0
pip install paxml==1.2.0
pip install -r requirements.txt
wget https://cancon.hpccube.com:65024/directlink/4/tensorflow/DAS1.0/tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
wget https://cancon.hpccube.com:65024/directlink/4/pytorch/DAS1.0/torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
pip install torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
Dockerfile(方法二)
docker build --no-cache -t timesfm:latest .
docker run -it --network=host --privileged=true --name=timesfm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ timesfm /bin/bash
cd /your_code_path/timesfm
pip install praxis==1.2.0
pip install paxml==1.2.0
pip install -r requirements.txt
wget https://cancon.hpccube.com:65024/directlink/4/tensorflow/DAS1.0/tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
wget https://cancon.hpccube.com:65024/directlink/4/pytorch/DAS1.0/torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
pip install torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
DTK软件栈:dtk24.04
python:python3.10
jax:0.4.23
tensorflow:2.13.1 (或 tensorflow-cpu==2.13.1)
torch:2.1.0
Tips:以上dtk软件栈、python、jax等DCU相关工具版本需要严格一一对应
2、其他非特殊库直接按照下面步骤进行安装
cd /your_code_path/timesfm
pip install praxis==1.2.0
pip install paxml==1.2.0
pip install -r requirements.txt
pip install tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
pip install torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
数据集
基准测试数据集运行时会gluonts自动下载,长期基准测试数据集可通过scnet或官网地址手动下载(官网地址需要魔法):
- AIDatasets / project-dependency / timesfm_jax · 极狐GitLab
- https://drive.google.com/file/d/1alE33S1GmP5wACMXaLu50rDIoVzBM4ik/view?usp=share_link
下载完成后,将数据解压到datasets目录下,若有自订目录需求,可修改timesfm/experiments/long_horizon_benchmarks/run_eval.py:
DATA_DICT = {
"ettm2": {
"boundaries": [34560, 46080, 57600],
"data_path": "./datasets/ETT-small/ETTm2.csv", # 修改数据集存放路径
"freq": "15min",
},
...
}
长期基准测试数据集目录结构如下:
── datasets
│ ├── electricity
│ ├── electricity.csv
│ ├── ETT-small
│ ├── ETTh1.csv
│ ├── ETTh2.csv
│ ├── ETTm1.csv
│ └── ETTm2.csv
│ ├── exchange_rate
│ └── exchange_rate.csv
│ ├── illness
│ └── national illness.csv
│ ├── traffic
│ └── traffic.csv
│ └── weather
│ └── weather.csv
训练
官方暂未开放
推理
检查点可通过scnet或以下方式进行下载:
# "model"是存储目录,可自订
cd timesfm
export HF_DATASETS_CACHE="path/timesfm/model"
export HF_ENDPOINT=https://hf-mirror.com # 设置下载地址
huggingface-cli download --resume-download google/timesfm-1.0-200m --local-dir model
推理运行代码:
# 运行基准测试
# 由于基准测试未提供调用数据集的接口,测试集要手动更改timesfm/experiments/extended_benchmarks/run_timesfm.py:dataset_names内填入所需数据集name(当前仓库已设置测试所需数据,不需额外更改)
cd timesfm
sh train.sh
# 运行长期基准测试
cd timesfm
sh train_long.sh
result
无
精度
测试数据:
1、基准测试:
"ett_small_15min",
"traffic",
"m3_quarterly",
"m3_yearly",
"tourism_yearly"
2、长期基准测试:
"etth1" --预测长度pred_len:96 192 336
"ettm1" --预测长度pred_len:96 192 336
根据测试结果情况填写表格,表格内的性能指标取在各数据集上的均值:
基准测试:
device | mae | mase | scaled_crps | smape | time |
---|---|---|---|---|---|
DCU K100 | 16779.15057 | 1.64840008 | 0.106485143 | 0.096774573 | 1.639674425 |
GPU A800 | 16780.92672 | 1.648167181 | 0.106450716 | 0.096736371 | 1.473794842 |
长期基准测试:
device | wape | smape | time |
---|---|---|---|
DCU K100 | 0.742821495 | 0.515499999 | 2.125586112 |
GPU A800 | 0.742829698 | 0.515345775 | 1.908264995 |
应用场景
算法类别
时序预测
热点应用行业
交通,零售,金融,气象