timesfm_jax一种预训练的时序预测基础模型

TimesFM

论文

A decoder-only foundation model for time-series forecasting

模型结构

TimesFM是一种基于区块的decoder-only模型,应用了自注意力机制和传统的位置编码,主要由三个组件组成:输入层、Transformer层和输出层。

1、输入层:将时间序列数据分割成相等长度的时序数据块(patch),然后通过残差块对每个时序数据块进行线性变化,进而得到Token。

2、Transformer层:应用了位置编码和自注意力机制。位置编码将时间信息注入Token(令牌)序列;自注意力允许模型学习序列中不同标记之间的依赖关系和关系;位置编码介入自注意力的构造意味着模型可以适应数据中不同的时间粒度和频率。

3、输出层:使用层归一化和残差连接,将输出Token映射到最终预测。

算法原理

TimesFM是google research推出的一种时序预测基础模型,在真实世界的大型语料库上进行了预训练。TimesFM能够适应不同的上下文和预测长度,并且与最新的LLM相比体量更小(200M参数),同时在未见过的数据集上也能zero-shot预测。具体的,TimesFM对时间序列进行分块和位置编码注入,再通过堆叠的Transformer层提炼出数据中的时间顺序信息和不同时间点的关系。

环境配置

注意:本仓库中的tensorflow只应用于长期基准测试读取数据集,运行源码发生了OOM错误,目前认为是tf-gpu与jax-gpu的内存分配冲突导致:GPU memory allocation — JAX documentation。解决方法:

1、安装tf-gpu,但在timesfm/experiments/long_horizon_benchmarks/data_loader.py中添加:
tf.config.experimental.set_visible_devices([], "GPU")
2、安装tf-cpu

-v 路径、docker_name和imageID根据实际情况修改

Docker(方法一)

docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04-py310
docker run -it --network=host --privileged=true --name=timesfm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ <imageID> /bin/bash  # <imageID>为以上拉取的docker的镜像ID替换

cd /your_code_path/timesfm
pip install praxis==1.2.0
pip install paxml==1.2.0
pip install -r requirements.txt
wget https://cancon.hpccube.com:65024/directlink/4/tensorflow/DAS1.0/tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
wget https://cancon.hpccube.com:65024/directlink/4/pytorch/DAS1.0/torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
pip install torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl

Dockerfile(方法二)

docker build --no-cache -t timesfm:latest .
docker run -it --network=host --privileged=true --name=timesfm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ timesfm /bin/bash

cd /your_code_path/timesfm
pip install praxis==1.2.0
pip install paxml==1.2.0
pip install -r requirements.txt
wget https://cancon.hpccube.com:65024/directlink/4/tensorflow/DAS1.0/tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
wget https://cancon.hpccube.com:65024/directlink/4/pytorch/DAS1.0/torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
pip install torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl

Anaconda(方法三)

1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/

DTK软件栈:dtk24.04
python:python3.10
jax:0.4.23
tensorflow:2.13.1 (或 tensorflow-cpu==2.13.1)
torch:2.1.0

Tips:以上dtk软件栈、python、jax等DCU相关工具版本需要严格一一对应

2、其他非特殊库直接按照下面步骤进行安装

cd /your_code_path/timesfm
pip install praxis==1.2.0
pip install paxml==1.2.0
pip install -r requirements.txt
pip install tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl #(或 tensorflow-cpu==2.13.1)
pip install torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl

数据集

基准测试数据集运行时会gluonts自动下载,长期基准测试数据集可通过scnet或官网地址手动下载(官网地址需要魔法):

下载完成后,将数据解压到datasets目录下,若有自订目录需求,可修改timesfm/experiments/long_horizon_benchmarks/run_eval.py:

DATA_DICT = {
    "ettm2": {
        "boundaries": [34560, 46080, 57600],
        "data_path": "./datasets/ETT-small/ETTm2.csv",  # 修改数据集存放路径
        "freq": "15min",
    },
    ...
}

长期基准测试数据集目录结构如下:

 ── datasets
    │   ├── electricity
    │             ├── electricity.csv
    │   ├── ETT-small
    │             ├── ETTh1.csv
    │             ├── ETTh2.csv
    │             ├── ETTm1.csv
    │             └── ETTm2.csv
    │   ├── exchange_rate
    │             └── exchange_rate.csv
    │   ├── illness
    │             └── national illness.csv
    │   ├── traffic
    │             └── traffic.csv
    │   └── weather
    │             └── weather.csv

训练

官方暂未开放

推理

检查点可通过scnet或以下方式进行下载:

# "model"是存储目录,可自订
cd timesfm
export HF_DATASETS_CACHE="path/timesfm/model"
export HF_ENDPOINT=https://hf-mirror.com  # 设置下载地址
huggingface-cli download --resume-download google/timesfm-1.0-200m --local-dir model

推理运行代码:

# 运行基准测试
# 由于基准测试未提供调用数据集的接口,测试集要手动更改timesfm/experiments/extended_benchmarks/run_timesfm.py:dataset_names内填入所需数据集name(当前仓库已设置测试所需数据,不需额外更改)
cd timesfm
sh train.sh 

# 运行长期基准测试
cd timesfm
sh train_long.sh 

result

精度

测试数据:

1、基准测试:
"ett_small_15min",
"traffic",
"m3_quarterly",
"m3_yearly",
"tourism_yearly"
2、长期基准测试:
"etth1" --预测长度pred_len:96 192 336
"ettm1" --预测长度pred_len:96 192 336

根据测试结果情况填写表格,表格内的性能指标取在各数据集上的均值:

基准测试:

devicemaemasescaled_crpssmapetime
DCU K10016779.150571.648400080.1064851430.0967745731.639674425
GPU A80016780.926721.6481671810.1064507160.0967363711.473794842

长期基准测试:

devicewapesmapetime
DCU K1000.7428214950.5154999992.125586112
GPU A8000.7428296980.5153457751.908264995

应用场景

算法类别

时序预测

热点应用行业

交通,零售,金融,气象

预训练权重

源码仓库及问题反馈

参考资料

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

技术瘾君子1573

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值