TimesFM (Time Series Foundation Model) installation notes
Reference: README
Installing and running TimesFM on Windows produced the following error:
TypeError                                 Traceback (most recent call last)
Cell In[9], line 8
      3 from jax._src import config
      4 config.update(
      5     "jax_platforms", {"cpu": "cpu", "gpu": "cuda", "tpu": ""}[timesfm_backend]
      6 )
----> 8 model = timesfm.TimesFm(
      9     context_len=512,
     10     horizon_len=128,
     11     input_patch_len=32,
     12     output_patch_len=128,
     13     num_layers=20,
     14     model_dims=1280,
     15     backend=timesfm_backend,
     16 )
     17 model.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

TypeError: TimesFmBase.__init__() got an unexpected keyword argument 'context_len'
I therefore decided to install TimesFM on Ubuntu instead, using a cloud development environment at ide.cloud.tencent.com.
We recommend at least 16GB RAM to load TimesFM dependencies.
Choose the environment carefully and make sure it has at least 16 GB of RAM.
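A quick way to check the machine before installing, sketched for Linux with only the standard library. The `os.sysconf` names are POSIX, and treating the README's "16GB" as a strict 16 GiB threshold is an assumption.

```python
import os

MIN_RAM_GIB = 16  # README: "at least 16GB RAM to load TimesFM dependencies"


def total_ram_bytes():
    """Physical memory in bytes, via POSIX sysconf (Linux and most Unixes)."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")


def meets_ram_requirement(ram_bytes, min_gib=MIN_RAM_GIB):
    # Strict GiB comparison: a VM sold as "16 GB" usually reports a little
    # less usable memory, so lower min_gib if this check is too strict.
    return ram_bytes >= min_gib * 1024**3


if __name__ == "__main__":
    ram = total_ram_bytes()
    print(f"total RAM: {ram / 1024**3:.1f} GiB, ok: {meets_ram_requirement(ram)}")
```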
Install Conda with Python 3.10.
Important: install pyenv and poetry.
## Installation
### Local installation using poetry
We will be using `pyenv` and `poetry`. In order to set these things up please follow the instructions [here](https://substack.com/home/post/p-148747960?r=28a5lx&utm_campaign=post&utm_medium=web).
Verify that both tools are installed:
(base) root@VM-0-170-ubuntu:/workspace/timesfm# pyenv --version
pyenv 2.4.22
(base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry --version
Poetry (version 1.8.5)
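The same check can be scripted. A minimal sketch, assuming both tools print an `X.Y.Z` version string as in the output above; a `None` result means the tool is not on PATH yet.

```python
import re
import shutil
import subprocess

VERSION_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)")


def parse_version(text):
    """Extract (major, minor, patch) from e.g. 'Poetry (version 1.8.5)'."""
    m = VERSION_RE.search(text)
    return tuple(int(g) for g in m.groups()) if m else None


def tool_version(tool):
    """Version tuple of `<tool> --version`, or None if tool is not on PATH."""
    if shutil.which(tool) is None:
        return None
    proc = subprocess.run([tool, "--version"], capture_output=True, text=True)
    return parse_version(proc.stdout + proc.stderr)


if __name__ == "__main__":
    for name in ("pyenv", "poetry"):
        print(name, tool_version(name) or "not found on PATH")
```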
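If `pyenv --version` or `poetry --version` does not print a version after installation, the environment variables still need to be set, as the installation output suggests:
Add `export PATH="/root/.local/bin:$PATH"` to your shell configuration file
To do this, open the profile in nano:
nano ~/.bash_profile
and add the following lines so that /root/.bash_profile contains: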
export PYENV_ROOT="$HOME/.pyenv"
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"
export PATH="/root/.local/bin:$PATH"
Save with Ctrl-O (Write Out), pressing Enter to confirm the file name, then exit with Ctrl-X.
Reload the profile so the changes take effect in the current shell:
source ~/.bash_profile
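After reloading the profile, the directories added above should appear on PATH. A sketch that reports any that are missing, assuming the default locations `~/.pyenv/bin`, `~/.pyenv/shims` (added by `pyenv init -`) and `~/.local/bin` (Poetry):

```python
import os

HOME = os.path.expanduser("~")
# Assumed default locations: pyenv's bin and shims dirs, Poetry's ~/.local/bin.
REQUIRED_DIRS = [
    os.path.join(HOME, ".pyenv", "bin"),
    os.path.join(HOME, ".pyenv", "shims"),
    os.path.join(HOME, ".local", "bin"),
]


def missing_path_entries(path_value, required):
    """Return the entries of `required` that do not appear in `path_value`."""
    present = {os.path.normpath(p) for p in path_value.split(os.pathsep) if p}
    return [d for d in required if os.path.normpath(d) not in present]


if __name__ == "__main__":
    missing = missing_path_entries(os.environ.get("PATH", ""), REQUIRED_DIRS)
    print("PATH looks complete" if not missing else f"missing: {missing}")
```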
Clone TimesFM:
git clone https://github.com/google-research/timesfm.git
Cloning into 'timesfm'...
remote: Enumerating objects: 667, done.
remote: Counting objects: 100% (665/665), done.
remote: Compressing objects: 100% (316/316), done.
remote: Total 667 (delta 353), reused 568 (delta 306), pack-reused 2 (from 1)
Receiving objects: 100% (667/667), 1.94 MiB | 3.76 MiB/s, done.
Resolving deltas: 100% (353/353), done.
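In the timesfm directory, open pyproject.toml and append the Aliyun PyPI mirror as a package source at the end; without it, the installation failed in this environment: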
[[tool.poetry.source]]
name = "aliyun"
url = "https://mirrors.aliyun.com/pypi/simple/"
priority = "primary"
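If you prefer to script this edit, here is a sketch that appends the same `[[tool.poetry.source]]` table once and skips it when an `aliyun` source is already present. The duplicate check is a plain text search, an assumption that is adequate for this file. Call it from the timesfm checkout as `add_aliyun_source("pyproject.toml")`.

```python
from pathlib import Path

ALIYUN_SOURCE = """
[[tool.poetry.source]]
name = "aliyun"
url = "https://mirrors.aliyun.com/pypi/simple/"
priority = "primary"
"""


def add_aliyun_source(pyproject):
    """Append the Aliyun source to pyproject.toml once; True if changed."""
    path = Path(pyproject)
    text = path.read_text(encoding="utf-8")
    if 'name = "aliyun"' in text:
        return False  # already present: do not add a duplicate source entry
    if text and not text.endswith("\n"):
        text += "\n"
    path.write_text(text + ALIYUN_SOURCE, encoding="utf-8")
    return True
```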
Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:
pyenv install 3.10
pyenv install 3.11
pyenv versions # to list the versions available (let's assume the versions are 3.10.15 and 3.11.10)
For PAX version installation do the following.
Run the following in the timesfm git directory:
pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E pax
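Running these commands in the timesfm git directory failed at first: pyenv had no 3.10.15 installed, so `pyenv local` failed, Poetry could not find `python3.10`, and it fell back to creating a Python 3.11 virtualenv: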
➜ timesfm git:(master) pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E pax
pyenv: version `3.10.15' not installed
Could not find the python executable python3.10
Creating virtualenv timesfm-jW3uZHTw-py3.11 in /root/.cache/pypoetry/virtualenvs
Updating dependencies
Resolving dependencies... (378.2s)
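To avoid this version mismatch, use the 3.10 patch release that pyenv actually installed instead of the README's example `3.10.15`. A sketch, assuming `pyenv versions --bare` prints one version per line; entries such as `system`, `3.10-dev` or virtualenv names are skipped.

```python
import subprocess


def newest_patch(versions, series="3.10"):
    """Newest 'series.N' entry (e.g. '3.10.15') in `versions`, or None."""
    prefix = series + "."
    found = []
    for v in (s.strip() for s in versions):
        tail = v[len(prefix):]
        # Keep plain releases only; skips 'system', '3.10-dev', virtualenvs.
        if v.startswith(prefix) and tail.isdigit():
            found.append((int(tail), v))
    return max(found)[1] if found else None


def installed_pyenv_versions():
    """Versions known to pyenv, one per line of `pyenv versions --bare`."""
    proc = subprocess.run(["pyenv", "versions", "--bare"],
                          capture_output=True, text=True, check=True)
    return proc.stdout.splitlines()
```

For example, `v = newest_patch(installed_pyenv_versions())`, then pass `v` to `pyenv local` and `poetry env use`.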
NumPy also had to be installed along the way.
From the README:
Please look into the README files in the respective benchmark directories within `experiments/` for instructions for running TimesFM on the respective benchmarks.
## Running TimesFM on the benchmark
We need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,
```
poetry add git+https://github.com/awslabs/gluon-ts.git
poetry lock
poetry install --only <pax or pytorch>
```
To run TimesFM on the benchmark, do:
```
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m(-pytorch) --backend="gpu"
```
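In the README command, `(-pytorch)` marks an optional suffix: the PAX/JAX install uses `google/timesfm-1.0-200m` and the PyTorch install uses `google/timesfm-1.0-200m-pytorch`. A sketch that assembles the command for either variant from the flags defined in `run_timesfm.py` (`benchmark_command` is a hypothetical helper; `--backend=cpu` in the example is an illustrative choice for a machine without a GPU):

```python
import shlex

MODEL_PATHS = {
    "pax": "google/timesfm-1.0-200m",
    "pytorch": "google/timesfm-1.0-200m-pytorch",
}


def benchmark_command(variant="pax", backend="gpu", **flags):
    """argv for the extended benchmark; extra flags become --name=value."""
    argv = [
        "poetry", "run", "python3", "-m",
        "experiments.extended_benchmarks.run_timesfm",
        f"--model_path={MODEL_PATHS[variant]}",
        f"--backend={backend}",
    ]
    argv += [f"--{name}={value}" for name, value in sorted(flags.items())]
    return argv


print(shlex.join(benchmark_command("pax", backend="cpu", batch_size=32)))
```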
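Activate the environment with `poetry shell`. The `bash: __vsc_prompt_cmd_original: command not found` messages did not prevent activation, as the final `(timesfm-py3.10)` prompt shows: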
(base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry shell
Spawning shell within /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10
. /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/bin/activate
bash: __vsc_prompt_cmd_original: command not found
(base) root@VM-0-170-ubuntu:/workspace/timesfm# . /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/bin/activate
bash: __vsc_prompt_cmd_original: command not found
(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm#
Then run the benchmark:
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m
JAX also had to be installed before this command ran successfully.
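A quick way to confirm that the activated virtualenv can see what the benchmark needs, JAX included. A sketch to run with `poetry run python3` from the repository; the module list is an assumption drawn from the script and the log (`absl` is the import name of absl-py).

```python
import importlib.util
import sys

# Assumed list: top-level modules used by run_timesfm.py and seen in the log.
REQUIRED = ["jax", "numpy", "pandas", "absl", "gluonts", "timesfm"]


def missing_modules(names):
    """Top-level module names that the current interpreter cannot find."""
    return [n for n in names if importlib.util.find_spec(n) is None]


if __name__ == "__main__":
    print("python", sys.version.split()[0], "prefix", sys.prefix)
    missing = missing_modules(REQUIRED)
    print("all found" if not missing else f"missing: {missing}")
```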
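The script experiments/extended_benchmarks/run_timesfm.py is shown below. Because it uses a relative import (`from .utils import ExperimentHandler`), it has to be run as a module (`python3 -m ...`) from the repository root, as above.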
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for timesfm."""
import os
import sys
import time
from absl import flags
import numpy as np
import pandas as pd
import timesfm
from .utils import ExperimentHandler
dataset_names = [
    "m1_monthly",
    "m1_quarterly",
    "m1_yearly",
    "m3_monthly",
    "m3_other",
    "m3_quarterly",
    "m3_yearly",
    "m4_quarterly",
    "m4_yearly",
    "tourism_monthly",
    "tourism_quarterly",
    "tourism_yearly",
    "nn5_daily_without_missing",
    "m5",
    "nn5_weekly",
    "traffic",
    "weather",
    "australian_electricity_demand",
    "car_parts_without_missing",
    "cif_2016",
    "covid_deaths",
    "ercot",
    "ett_small_15min",
    "ett_small_1h",
    "exchange_rate",
    "fred_md",
    "hospital",
]

context_dict = {
    "cif_2016": 32,
    "tourism_yearly": 64,
    "covid_deaths": 64,
    "tourism_quarterly": 64,
    "tourism_monthly": 64,
    "m1_monthly": 64,
    "m1_quarterly": 64,
    "m1_yearly": 64,
    "m3_monthly": 64,
    "m3_other": 64,
    "m3_quarterly": 64,
    "m3_yearly": 64,
    "m4_quarterly": 64,
    "m4_yearly": 64,
}
_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-1.0-200m",
                                  "Path to model")
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
_HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
_BACKEND = flags.DEFINE_string("backend", "gpu", "Backend")
_NUM_JOBS = flags.DEFINE_integer("num_jobs", 1, "Number of jobs")
_SAVE_DIR = flags.DEFINE_string("save_dir", "./results", "Save directory")

QUANTILES = list(np.arange(1, 10) / 10.0)


def main():
  results_list = []
  tfm = timesfm.TimesFm(
      hparams=timesfm.TimesFmHparams(
          backend=_BACKEND.value,
          per_core_batch_size=_BATCH_SIZE.value,
          horizon_len=_HORIZON.value,
      ),
      checkpoint=timesfm.TimesFmCheckpoint(
          huggingface_repo_id=_MODEL_PATH.value),
  )
  run_id = np.random.randint(100000)
  model_name = "timesfm"
  for dataset in dataset_names:
    print(f"Evaluating model {model_name} on dataset {dataset}", flush=True)
    exp = ExperimentHandler(dataset, quantiles=QUANTILES)
    if dataset in context_dict:
      context_len = context_dict[dataset]
    else:
      context_len = 512
    train_df = exp.train_df
    freq = exp.freq
    init_time = time.time()
    fcsts_df = tfm.forecast_on_df(
        inputs=train_df,
        freq=freq,
        value_name="y",
        model_name=model_name,
        forecast_context_len=context_len,
        num_jobs=_NUM_JOBS.value,
    )
    total_time = time.time() - init_time
    time_df = pd.DataFrame({"time": [total_time], "model": model_name})
    results = exp.evaluate_from_predictions(models=[model_name],
                                            fcsts_df=fcsts_df,
                                            times_df=time_df)
    print(results, flush=True)
    results_list.append(results)
    results_full = pd.concat(results_list)
    save_path = os.path.join(_SAVE_DIR.value, str(run_id))
    print(f"Saving results to {save_path}", flush=True)
    os.makedirs(save_path, exist_ok=True)
    results_full.to_csv(f"{save_path}/results.csv")


if __name__ == "__main__":
  FLAGS = flags.FLAGS
  FLAGS(sys.argv)
  main()
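The script collects the per-dataset results in `<save_dir>/<run_id>/results.csv` (by default under `./results/` with a random run id). Since `to_csv` keeps the DataFrame index, the file should start with an unnamed index column followed by `dataset`, `metric`, `model`, `value`. A stdlib-only sketch, assuming that layout, that loads a results file into one dict per dataset and finds the newest run:

```python
import csv
from collections import defaultdict
from pathlib import Path


def load_results(path):
    """Map dataset -> {metric: value} from a run_timesfm results.csv."""
    table = defaultdict(dict)
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            table[row["dataset"]][row["metric"]] = float(row["value"])
    return dict(table)


def latest_results(save_dir="./results"):
    """Newest results.csv under save_dir (run ids are random, so use mtime)."""
    files = sorted(Path(save_dir).glob("*/results.csv"),
                   key=lambda p: p.stat().st_mtime)
    return files[-1] if files else None
```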
Output:
(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
/root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/lib/python3.10/site-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.
warnings.warn(
Fetching 5 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 63937.56it/s]
Multiprocessing context has already been set.
Constructing model weights.
Constructed model weights in 2.83 seconds.
Restoring checkpoint from /root/.cache/huggingface/hub/models--google--timesfm-1.0-200m/snapshots/8775f7531211ac864b739fe776b0b255c277e2be/checkpoints.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
`train_state_unpadded_shape_dtype_struct` during checkpoint
saving/restoring, to avoid potential silent bugs when loading
checkpoints to incompatible unpadded shapes of TrainState.
Restored checkpoint in 1.50 seconds.
Jitting decoding.
Jitted decoding in 20.78 seconds.
Evaluating model timesfm on dataset m1_monthly
/root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/lib/python3.10/site-packages/gluonts/time_feature/seasonality.py:47: FutureWarning: 'M' is deprecated and will be removed in a future version, please use 'ME' instead.
offset = pd.tseries.frequencies.to_offset(freq)
Multiprocessing context has already been set.
/root/miniforge3/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
Processing dataframe with single process.
Finished preprocessing dataframe.
Finished forecasting.
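Results (excerpt covering the tourism datasets; the `time` rows give the seconds spent in `forecast_on_df`):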
dataset metric model value
0 tourism_monthly mae timesfm 1970.148438
1 tourism_monthly mase timesfm 1.541883
2 tourism_monthly scaled_crps timesfm 0.121862
3 tourism_monthly smape timesfm 0.101539
4 tourism_monthly time timesfm 0.762044
---------------
dataset metric model value
0 tourism_quarterly mae timesfm 7439.246094
1 tourism_quarterly mase timesfm 1.731996
2 tourism_quarterly scaled_crps timesfm 0.087743
3 tourism_quarterly smape timesfm 0.083795
4 tourism_quarterly time timesfm 0.839042
-----------------
dataset metric model value
0 tourism_yearly mae timesfm 82434.085938
1 tourism_yearly mase timesfm 3.233205
2 tourism_yearly scaled_crps timesfm 0.129402
3 tourism_yearly smape timesfm 0.181012
4 tourism_yearly time timesfm 1.023866
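For reading these tables: `mae` is in the data's own units (hence the large tourism_yearly value), while `mase` and `smape` are scale-free. The sketch below uses common textbook definitions as an illustration only; the exact variants computed by the benchmark's `ExperimentHandler` (for example the sMAPE scaling or the MASE seasonality) may differ.

```python
def mae(y, yhat):
    """Mean absolute error."""
    return sum(abs(a - b) for a, b in zip(y, yhat)) / len(y)


def smape(y, yhat):
    # Variant bounded in [0, 1]; other definitions multiply by 2 or 200.
    terms = [abs(a - b) / (abs(a) + abs(b)) if (a or b) else 0.0
             for a, b in zip(y, yhat)]
    return sum(terms) / len(terms)


def mase(y, yhat, y_train, season=1):
    # Scale: in-sample MAE of the seasonal naive forecast y[t] = y[t - season].
    naive = [abs(y_train[t] - y_train[t - season])
             for t in range(season, len(y_train))]
    return mae(y, yhat) / (sum(naive) / len(naive))
```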