TimesFM (Time Series Foundation Model) Installation (2)

Introduction to Installing TimesFM (Time Series Foundation Model)

README

Part (1): TimesFM (Time Series Foundation Model) Installation Introduction, https://blog.youkuaiyun.com/chenchihwen/article/details/144359861?spm=1001.2014.3001.5501

Installing and running it on Windows failed with the following error:

{
    "name": "TypeError",
    "message": "TimesFmBase.__init__() got an unexpected keyword argument 'context_len'",
    "stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 8
      3 from jax._src import config
      4 config.update(
      5     \"jax_platforms\", {\"cpu\": \"cpu\", \"gpu\": \"cuda\", \"tpu\": \"\"}[timesfm_backend]
      6 )
----> 8 model = timesfm.TimesFm(
      9     context_len=512,
     10     horizon_len=128,
     11     input_patch_len=32,
     12     output_patch_len=128,
     13     num_layers=20,
     14     model_dims=1280,
     15     backend=timesfm_backend,
     16 )
     17 model.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")

TypeError: TimesFmBase.__init__() got an unexpected keyword argument 'context_len'"
}
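The error suggests the notebook was written for an older constructor signature. The benchmark script quoted later in this post (run_timesfm.py, which prints "TimesFM v1.2.0" when run) does not pass `context_len` and the other settings directly. Instead, it wraps them in `timesfm.TimesFmHparams` and wraps the checkpoint in `timesfm.TimesFmCheckpoint`. The sketch below regroups the notebook's arguments in that style. It assumes that the v1.2.x `TimesFmHparams` fields have the same names as the old keywords. `split_legacy_kwargs` and `build_model` are hypothetical helpers, and `backend="cpu"` stands in for the notebook's `timesfm_backend` variable.

```python
from __future__ import annotations

# Hypothetical helper data: legacy TimesFm(...) keywords treated as hyperparameters.
HPARAM_KEYS = {
    "context_len",
    "horizon_len",
    "input_patch_len",
    "output_patch_len",
    "num_layers",
    "model_dims",
    "backend",
}


def split_legacy_kwargs(legacy: dict, repo_id: str) -> tuple[dict, dict]:
    """Split old-style constructor kwargs into (hparams, checkpoint) dicts."""
    unknown = set(legacy) - HPARAM_KEYS
    if unknown:
        raise ValueError(f"unexpected legacy keys: {sorted(unknown)}")
    hparams = dict(legacy)
    checkpoint = {"huggingface_repo_id": repo_id}
    return hparams, checkpoint


def build_model(hparams: dict, checkpoint: dict):
    """Construct the model with the v1.2-style API (needs timesfm installed)."""
    import timesfm  # lazy import: the helper above runs without timesfm

    return timesfm.TimesFm(
        hparams=timesfm.TimesFmHparams(**hparams),
        checkpoint=timesfm.TimesFmCheckpoint(**checkpoint),
    )


# Arguments from the failing notebook cell ("cpu" stands in for timesfm_backend).
legacy_args = dict(
    context_len=512,
    horizon_len=128,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280,
    backend="cpu",
)
hparams, checkpoint = split_legacy_kwargs(legacy_args, "google/timesfm-1.0-200m")
print(hparams)
print(checkpoint)
```

Inside the Poetry environment, `model = build_model(hparams, checkpoint)` would take the place of the failing `timesfm.TimesFm(...)` call. run_timesfm.py passes the checkpoint to the constructor and never calls `load_from_checkpoint`, so the separate loading line is presumably unnecessary with this API. Check the README of the installed version to confirm.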

I therefore decided to install TimesFM on Ubuntu instead.

The installation was done in an environment on ide.cloud.tencent.com.

We recommend at least 16GB RAM to load TimesFM dependencies.

Choose the environment carefully and make sure it has more than 16 GB of RAM.
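One way to check a candidate Linux VM's memory before installing is to query the OS with `os.sysconf`. This is a minimal sketch for Linux. It reports GiB, and the usable total is usually a little below the machine's nominal size, so a VM sold as "16 GB" may show slightly less than 16.

```python
import os

MIN_RAM_GB = 16  # the README recommends at least 16GB RAM


def total_ram_gb() -> float:
    """Total physical memory in GiB, from POSIX sysconf (Linux)."""
    page_size = os.sysconf("SC_PAGE_SIZE")
    num_pages = os.sysconf("SC_PHYS_PAGES")
    return page_size * num_pages / 1024**3


ram = total_ram_gb()
verdict = "meets" if ram >= MIN_RAM_GB else "is below"
print(f"Total RAM {ram:.1f} GiB {verdict} the {MIN_RAM_GB} GB recommendation")
```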

Install Conda with Python 3.10.

Important step: install pyenv and poetry.

## Installation

### Local installation using poetry

We will be using `pyenv` and `poetry`. In order to set these things up please follow the instructions [here](https://substack.com/home/post/p-148747960?r=28a5lx&utm_campaign=post&utm_medium=web). Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:

Confirm that both are installed:

(base) root@VM-0-170-ubuntu:/workspace/timesfm# pyenv --version
pyenv 2.4.22
(base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry --version
Poetry (version 1.8.5)

If the version numbers are not shown after installation, the environment variables still need to be set, as the installer notes:

Add `export PATH="/root/.local/bin:$PATH"` to your shell configuration file

To modify the environment variables, open the profile in nano:

nano ~/.bash_profile

Then add the environment variables. /root/.bash_profile should contain the following:
export PYENV_ROOT="$HOME/.pyenv"
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"
export PATH="/root/.local/bin:$PATH"

After editing, press Ctrl-O to write out the file, then Ctrl-X to exit.

The changes only take effect after the profile is reloaded:

source ~/.bash_profile
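To confirm that the reloaded profile actually put both tools on PATH, this sketch checks the directories exported above. It also uses `shutil.which` to show where `pyenv` and `poetry` resolve. It assumes the default `PYENV_ROOT` of `~/.pyenv` and the `/root/.local/bin` location from the Poetry message. Run it from the same shell after `source ~/.bash_profile` so that it inherits the updated environment.

```python
import os
import shutil


def missing_path_dirs(required, path_value=None):
    """Return the entries of `required` that are not listed in PATH."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    entries = {os.path.normpath(p) for p in path_value.split(os.pathsep) if p}
    return [d for d in required if os.path.normpath(d) not in entries]


# Directories exported in ~/.bash_profile above (assumes PYENV_ROOT=~/.pyenv).
required_dirs = [os.path.expanduser("~/.pyenv/bin"), "/root/.local/bin"]
print("missing from PATH:", missing_path_dirs(required_dirs) or "none")
for tool in ("pyenv", "poetry"):
    # shutil.which returns None when the executable cannot be found.
    print(f"{tool}: {shutil.which(tool)}")
```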

Clone TimesFM:

git clone https://github.com/google-research/timesfm.git

git clone https://github.com/google-research/timesfm.git
Cloning into 'timesfm'...
remote: Enumerating objects: 667, done.
remote: Counting objects: 100% (665/665), done.
remote: Compressing objects: 100% (316/316), done.
remote: Total 667 (delta 353), reused 568 (delta 306), pack-reused 2 (from 1)
Receiving objects: 100% (667/667), 1.94 MiB | 3.76 MiB/s, done.
Resolving deltas: 100% (353/353), done.

In the timesfm directory, open pyproject.toml and append the Aliyun mirror as a package source at the end. Otherwise the installation does not go through:

[[tool.poetry.source]]
name = "aliyun"
url = "https://mirrors.aliyun.com/pypi/simple/"
priority = "primary"
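Editing pyproject.toml by hand works. If the setup is repeated, for example on a fresh VM, a small script that appends the Aliyun source only when it is missing avoids duplicate entries. This is a sketch that assumes it runs from the root of the timesfm checkout. `with_aliyun_source` and `patch_pyproject` are hypothetical helper names.

```python
from pathlib import Path

# The same source block as above, appended verbatim.
ALIYUN_SOURCE = '''
[[tool.poetry.source]]
name = "aliyun"
url = "https://mirrors.aliyun.com/pypi/simple/"
priority = "primary"
'''


def with_aliyun_source(text: str) -> str:
    """Return pyproject text that contains the Aliyun source exactly once."""
    if "mirrors.aliyun.com/pypi/simple" in text:
        return text  # already configured; leave the content unchanged
    return text.rstrip("\n") + "\n" + ALIYUN_SOURCE


def patch_pyproject(path: str = "pyproject.toml") -> None:
    """Rewrite the file in place (run from the timesfm checkout)."""
    p = Path(path)
    p.write_text(with_aliyun_source(p.read_text()))
```

Because `with_aliyun_source` is idempotent, calling `patch_pyproject()` a second time leaves the file unchanged.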

 Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:

pyenv install 3.10
pyenv install 3.11
pyenv versions # to list the versions available (let's assume the versions are 3.10.15 and 3.11.10)
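`pyenv local` and `poetry env use` need the exact patch version that is installed, and `pyenv install 3.10` may resolve to a different patch release than the README's example 3.10.15. The log further down shows what happens when 3.10.15 is not present. This sketch parses `pyenv versions` output and picks the newest installed release for a given minor version. The sample text uses the README's assumed versions, 3.10.15 and 3.11.10, with an illustrative `(set by ...)` suffix. `newest_patch` is a hypothetical helper.

```python
from __future__ import annotations

import re


def newest_patch(versions_output: str, minor: str) -> str | None:
    """Return the newest installed X.Y.Z matching the 'X.Y' prefix, or None."""
    found = []
    for line in versions_output.splitlines():
        # Lines look like "* 3.10.15 (set by ...)", "  3.11.10" or "  system".
        parts = line.replace("*", " ").split()
        if not parts:
            continue
        m = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", parts[0])
        if m and f"{m[1]}.{m[2]}" == minor:
            found.append(tuple(int(g) for g in m.groups()))
    # Compare as integer tuples so 3.10.15 beats 3.10.9.
    return ".".join(map(str, max(found))) if found else None


# Illustrative output using the README's assumed versions.
sample = "  system\n* 3.10.15 (set by /root/.pyenv/version)\n  3.11.10\n"
print(newest_patch(sample, "3.10"), newest_patch(sample, "3.11"))
```

On the VM, pass it the real output, for example `subprocess.run(["pyenv", "versions"], capture_output=True, text=True).stdout`. Then use the returned name with `pyenv local` and `poetry env use`.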

For PAX version installation do the following.

Run the following in the timesfm git directory:

pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E  pax

Running it in the timesfm git directory produced:

➜  timesfm git:(master) pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E  pax
pyenv: version `3.10.15' not installed

Could not find the python executable python3.10
Creating virtualenv timesfm-jW3uZHTw-py3.11 in /root/.cache/pypoetry/virtualenvs
Updating dependencies
Resolving dependencies... (378.2s)

numpy also had to be installed along the way.

According to the README:

Please look into the README files in the respective benchmark directories within `experiments/` for instructions for running TimesFM on the respective benchmarks.

## Running TimesFM on the benchmark

We need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,

```
poetry add git+https://github.com/awslabs/gluon-ts.git
poetry lock
poetry install --only <pax or pytorch>
```

To run the timesfm on the benchmark do:

```
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m(-pytorch) --backend="gpu"
```
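The `(-pytorch)` in the README command means that the model path takes a `-pytorch` suffix when the PyTorch checkpoint is used. This sketch builds the argument list for either case. `benchmark_argv` is a hypothetical helper. The `--backend` values follow the flag in run_timesfm.py (default `"gpu"`), and `"cpu"` is assumed to be the choice for machines without a GPU.

```python
from __future__ import annotations

import shlex

BASE_REPO = "google/timesfm-1.0-200m"


def benchmark_argv(pytorch: bool = False, backend: str = "gpu") -> list[str]:
    """Argument list for the README's benchmark command (hypothetical helper)."""
    # "(-pytorch)" in the README: append "-pytorch" for the PyTorch checkpoint.
    model_path = BASE_REPO + ("-pytorch" if pytorch else "")
    return [
        "poetry", "run", "python3", "-m",
        "experiments.extended_benchmarks.run_timesfm",
        f"--model_path={model_path}",
        f"--backend={backend}",
    ]


print(shlex.join(benchmark_argv()))
print(shlex.join(benchmark_argv(pytorch=True, backend="cpu")))
```

Run from the timesfm directory, `subprocess.run(benchmark_argv(), check=True)` would launch the benchmark.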

Run poetry shell:

(base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry shell
Spawning shell within /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10
. /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/bin/activate
bash: __vsc_prompt_cmd_original: command not found
(base) root@VM-0-170-ubuntu:/workspace/timesfm# . /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/bin/activate
bash: __vsc_prompt_cmd_original: command not found
(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm# 

Then run: poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m

Before this would run successfully, jax also had to be installed.
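numpy and jax both had to be added by hand, so it can help to list what is still missing in the Poetry environment before launching the benchmark. This sketch checks the modules that run_timesfm.py imports (numpy, pandas, absl, timesfm), plus jax and gluonts, which appear in the run log. Note that the `absl` module comes from the `absl-py` package on PyPI.

```python
import importlib.util

# Imported by run_timesfm.py, plus jax (PAX backend) and gluonts (benchmarks).
REQUIRED = ["numpy", "pandas", "absl", "jax", "gluonts", "timesfm"]


def missing_modules(names):
    """Return the top-level module names this interpreter cannot find."""
    return [n for n in names if importlib.util.find_spec(n) is None]


print("missing:", missing_modules(REQUIRED) or "none")
```

Run it with `poetry run python` so that it checks the virtualenv rather than the conda base environment.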

The code of run_timesfm.py is as follows:

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for timesfm."""

import os
import sys
import time

from absl import flags
import numpy as np
import pandas as pd
import timesfm

from .utils import ExperimentHandler

dataset_names = [
    "m1_monthly",
    "m1_quarterly",
    "m1_yearly",
    "m3_monthly",
    "m3_other",
    "m3_quarterly",
    "m3_yearly",
    "m4_quarterly",
    "m4_yearly",
    "tourism_monthly",
    "tourism_quarterly",
    "tourism_yearly",
    "nn5_daily_without_missing",
    "m5",
    "nn5_weekly",
    "traffic",
    "weather",
    "australian_electricity_demand",
    "car_parts_without_missing",
    "cif_2016",
    "covid_deaths",
    "ercot",
    "ett_small_15min",
    "ett_small_1h",
    "exchange_rate",
    "fred_md",
    "hospital",
]

context_dict = {
    "cif_2016": 32,
    "tourism_yearly": 64,
    "covid_deaths": 64,
    "tourism_quarterly": 64,
    "tourism_monthly": 64,
    "m1_monthly": 64,
    "m1_quarterly": 64,
    "m1_yearly": 64,
    "m3_monthly": 64,
    "m3_other": 64,
    "m3_quarterly": 64,
    "m3_yearly": 64,
    "m4_quarterly": 64,
    "m4_yearly": 64,
}

_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-1.0-200m",
                                  "Path to model")
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
_HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
_BACKEND = flags.DEFINE_string("backend", "gpu", "Backend")
_NUM_JOBS = flags.DEFINE_integer("num_jobs", 1, "Number of jobs")
_SAVE_DIR = flags.DEFINE_string("save_dir", "./results", "Save directory")

QUANTILES = list(np.arange(1, 10) / 10.0)


def main():
  results_list = []
  tfm = timesfm.TimesFm(
      hparams=timesfm.TimesFmHparams(
          backend=_BACKEND.value,
          per_core_batch_size=_BATCH_SIZE.value,
          horizon_len=_HORIZON.value,
      ),
      checkpoint=timesfm.TimesFmCheckpoint(
          huggingface_repo_id=_MODEL_PATH.value),
  )
  run_id = np.random.randint(100000)
  model_name = "timesfm"
  for dataset in dataset_names:
    print(f"Evaluating model {model_name} on dataset {dataset}", flush=True)
    exp = ExperimentHandler(dataset, quantiles=QUANTILES)

    if dataset in context_dict:
      context_len = context_dict[dataset]
    else:
      context_len = 512
    train_df = exp.train_df
    freq = exp.freq
    init_time = time.time()
    fcsts_df = tfm.forecast_on_df(
        inputs=train_df,
        freq=freq,
        value_name="y",
        model_name=model_name,
        forecast_context_len=context_len,
        num_jobs=_NUM_JOBS.value,
    )
    total_time = time.time() - init_time
    time_df = pd.DataFrame({"time": [total_time], "model": model_name})
    results = exp.evaluate_from_predictions(models=[model_name],
                                            fcsts_df=fcsts_df,
                                            times_df=time_df)
    print(results, flush=True)
    results_list.append(results)
    results_full = pd.concat(results_list)
    save_path = os.path.join(_SAVE_DIR.value, str(run_id))
    print(f"Saving results to {save_path}", flush=True)
    os.makedirs(save_path, exist_ok=True)
    results_full.to_csv(f"{save_path}/results.csv")


if __name__ == "__main__":
  FLAGS = flags.FLAGS
  FLAGS(sys.argv)
  main()
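The script passes `exp.train_df` to `tfm.forecast_on_df(..., value_name="y")`. The sketch below builds a toy frame in the long format that this call is assumed to expect: one row per series and timestamp, with `unique_id`, `ds` and `y` columns as in the TimesFM README examples. The series and values are synthetic placeholders, not benchmark data, and `toy_long_df` is a hypothetical helper. It uses the month-start alias `"MS"` because, as the run log below warns, pandas has deprecated `"M"`.

```python
import numpy as np
import pandas as pd


def toy_long_df(n_series: int = 2, n_points: int = 24, freq: str = "MS") -> pd.DataFrame:
    """One row per (series, timestamp) with columns unique_id, ds, y."""
    rng = np.random.default_rng(0)
    frames = []
    for i in range(n_series):
        frames.append(pd.DataFrame({
            "unique_id": f"series_{i}",  # series identifier, broadcast to all rows
            "ds": pd.date_range("2020-01-01", periods=n_points, freq=freq),
            "y": rng.normal(100.0, 10.0, n_points),  # synthetic values
        }))
    return pd.concat(frames, ignore_index=True)


df = toy_long_df()
print(df.head())
# Assumed usage with a loaded model, mirroring run_timesfm.py:
# fcsts_df = tfm.forecast_on_df(inputs=df, freq="MS", value_name="y", num_jobs=1)
```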

Output:

(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
/root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/lib/python3.10/site-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.
  warnings.warn(
Fetching 5 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 63937.56it/s]
Multiprocessing context has already been set.
Constructing model weights.
Constructed model weights in 2.83 seconds.
Restoring checkpoint from /root/.cache/huggingface/hub/models--google--timesfm-1.0-200m/snapshots/8775f7531211ac864b739fe776b0b255c277e2be/checkpoints.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
          `train_state_unpadded_shape_dtype_struct` during checkpoint
          saving/restoring, to avoid potential silent bugs when loading
          checkpoints to incompatible unpadded shapes of TrainState.
Restored checkpoint in 1.50 seconds.
Jitting decoding.
Jitted decoding in 20.78 seconds.
Evaluating model timesfm on dataset m1_monthly
/root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/lib/python3.10/site-packages/gluonts/time_feature/seasonality.py:47: FutureWarning: 'M' is deprecated and will be removed in a future version, please use 'ME' instead.
  offset = pd.tseries.frequencies.to_offset(freq)
Multiprocessing context has already been set.
/root/miniforge3/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
Processing dataframe with single process.
Finished preprocessing dataframe.
Finished forecasting.

Forecast results:

           dataset       metric    model        value
0  tourism_monthly          mae  timesfm  1970.148438
1  tourism_monthly         mase  timesfm     1.541883
2  tourism_monthly  scaled_crps  timesfm     0.121862
3  tourism_monthly        smape  timesfm     0.101539
4  tourism_monthly         time  timesfm     0.762044

---------------

             dataset       metric    model        value
0  tourism_quarterly          mae  timesfm  7439.246094
1  tourism_quarterly         mase  timesfm     1.731996
2  tourism_quarterly  scaled_crps  timesfm     0.087743
3  tourism_quarterly        smape  timesfm     0.083795
4  tourism_quarterly         time  timesfm     0.839042

-----------------

          dataset       metric    model         value
0  tourism_yearly          mae  timesfm  82434.085938
1  tourism_yearly         mase  timesfm      3.233205
2  tourism_yearly  scaled_crps  timesfm      0.129402
3  tourism_yearly        smape  timesfm      0.181012
4  tourism_yearly         time  timesfm      1.023866
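results.csv stores the metrics in the same long format printed above (dataset, metric, model, value), which makes comparison across datasets awkward. Pivoting produces one row per dataset and one column per metric. The sketch below rebuilds the three tables from the numbers shown above. With the real file, `pd.read_csv("./results/<run_id>/results.csv", index_col=0)` would replace the hand-typed dictionary, since the script writes the default index.

```python
import pandas as pd

METRICS = ["mae", "mase", "scaled_crps", "smape", "time"]
# Values copied from the three result tables above.
RESULTS = {
    "tourism_monthly": [1970.148438, 1.541883, 0.121862, 0.101539, 0.762044],
    "tourism_quarterly": [7439.246094, 1.731996, 0.087743, 0.083795, 0.839042],
    "tourism_yearly": [82434.085938, 3.233205, 0.129402, 0.181012, 1.023866],
}

# Same long layout as results.csv: one row per (dataset, metric).
long_df = pd.DataFrame(
    [
        {"dataset": ds, "metric": m, "model": "timesfm", "value": v}
        for ds, values in RESULTS.items()
        for m, v in zip(METRICS, values)
    ]
)
wide = long_df.pivot(index="dataset", columns="metric", values="value")
print(wide.round(4))
```

If several models are evaluated, use `index=["dataset", "model"]` to keep them apart.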
