5分钟上手Time-Series-Library:从模型训练到部署全流程指南
你是否还在为时间序列预测模型的选择而烦恼?面对Autoformer、TimesNet等数十种模型不知如何入手?本文将带你零基础掌握Time-Series-Library(TSLib)的完整使用流程,从环境搭建到模型部署,让你轻松应对长短期预测、异常检测等五大时间序列任务。
1. 项目概述
Time-Series-Library(TSLib)是一个专注于深度学习时间序列分析的开源库,提供了统一的代码框架来评估和开发先进的时间序列模型。该库支持五大主流任务:长短期预测、缺失值填补、异常检测和分类。
官方文档:README.md 社区教程:TimesNet_tutorial.ipynb
2. 环境准备
2.1 安装依赖
首先克隆项目仓库并安装所需依赖:
git clone https://gitcode.com/GitHub_Trending/ti/Time-Series-Library
cd Time-Series-Library
pip install -r requirements.txt
2.2 数据集准备
TSLib支持多种公开数据集,可从Google Drive或百度网盘下载预处理好的数据集,解压后放在./dataset目录下。
3. 模型训练实战
3.1 快速启动训练
TSLib提供了便捷的脚本文件来启动不同任务的训练。以长短期预测任务为例:
# 长期预测
bash ./scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh
# 短期预测
bash ./scripts/short_term_forecast/TimesNet_M4.sh
# 异常检测
bash ./scripts/anomaly_detection/PSM/TimesNet.sh
3.2 自定义训练参数
通过修改run.py文件中的参数,可以自定义训练过程。关键参数包括:
# 模型参数
parser.add_argument('--model', type=str, default='TimesNet', help='模型名称')
parser.add_argument('--seq_len', type=int, default=96, help='输入序列长度')
parser.add_argument('--pred_len', type=int, default=96, help='预测序列长度')
# 训练参数
parser.add_argument('--train_epochs', type=int, default=10, help='训练轮数')
parser.add_argument('--batch_size', type=int, default=32, help='批次大小')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='学习率')
训练配置源码:run.py
3.3 训练流程解析
TSLib的训练流程主要在exp目录下的文件中实现,以长期预测任务为例,核心代码在exp/exp_long_term_forecasting.py中:
def train(self, setting):
# 数据准备
train_data, train_loader = self._get_data(flag='train')
vali_data, vali_loader = self._get_data(flag='val')
# 模型初始化
model_optim = self._select_optimizer()
criterion = self._select_criterion()
# 训练循环
for epoch in range(self.args.train_epochs):
self.model.train()
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
model_optim.zero_grad()
batch_x = batch_x.float().to(self.device)
batch_y = batch_y.float().to(self.device)
# 前向传播
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
# 计算损失
loss = criterion(outputs, batch_y)
# 反向传播
loss.backward()
model_optim.step()
训练逻辑源码:exp/exp_long_term_forecasting.py
4. 模型架构解析
4.1 核心模型组件
TSLib包含多种先进的时间序列模型,如TimesNet、Autoformer、Informer等。以TimesNet为例,其核心组件是TimesBlock,通过FFT提取周期特征并使用2D卷积捕获时间依赖关系:
class TimesBlock(nn.Module):
def __init__(self, configs):
super(TimesBlock, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.k = configs.top_k
# 卷积块
self.conv = nn.Sequential(
Inception_Block_V1(configs.d_model, configs.d_ff, num_kernels=configs.num_kernels),
nn.GELU(),
Inception_Block_V1(configs.d_ff, configs.d_model, num_kernels=configs.num_kernels)
)
def forward(self, x):
# FFT获取周期特征
period_list, period_weight = FFT_for_Period(x, self.k)
# 周期特征处理与卷积
for i in range(self.k):
period = period_list[i]
# 填充与重塑
out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()
# 2D卷积
out = self.conv(out)
# 重塑回原形状
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
res.append(out[:, :(self.seq_len + self.pred_len), :])
# 加权聚合
res = torch.sum(res * period_weight, -1)
return res + x # 残差连接
TimesNet模型源码:models/TimesNet.py
4.2 支持的模型列表
TSLib支持多种先进的时间序列模型,主要模型如下表所示:
| 模型名称 | 任务类型 | 特点 |
|---|---|---|
| TimesNet | 全任务 | 基于2D卷积捕获时间序列周期特征 |
| Autoformer | 预测、填补 | 基于自相关机制的分解Transformer |
| Informer | 预测 | 高效Transformer,适用于长序列 |
| DLinear | 预测 | 简单有效的线性模型,速度快 |
| Mamba | 全任务 | 基于状态空间模型的高效序列模型 |
模型注册源码:exp/exp_basic.py
5. 模型评估与部署
5.1 模型评估指标
TSLib提供了多种评估指标,在utils/metrics.py中实现,包括MSE、MAE、RMSE等:
def MSE(pred, true):
return torch.mean((pred - true) ** 2)
def MAE(pred, true):
return torch.mean(torch.abs(pred - true))
def RMSE(pred, true):
return torch.sqrt(torch.mean((pred - true) ** 2))
评估指标源码:utils/metrics.py
5.2 模型保存与加载
训练完成后,模型会保存在checkpoints目录下,可以通过以下代码加载模型:
def load_model(self, path):
model_dict = torch.load(path)
self.model.load_state_dict(model_dict)
return self.model
5.3 部署流程
- 导出模型权重
- 编写推理脚本
- 集成到生产环境
推理示例代码:
import torch
from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
import argparse
# 加载模型配置
parser = argparse.ArgumentParser(description='TSForecast')
# 添加必要的参数...
args = parser.parse_args()
# 初始化模型
exp = Exp_Long_Term_Forecast(args)
model = exp.model
# 加载权重
model.load_state_dict(torch.load('checkpoints/your_model_path.pth'))
# 推理
def predict(data):
model.eval()
with torch.no_grad():
pred = model(data)
return pred
# 示例数据
data = torch.randn(1, 96, 7) # (batch_size, seq_len, features)
result = predict(data)
print(result.shape) # (1, 96, 7)
6. 自定义模型开发
6.1 添加新模型步骤
- 在
models目录下创建新模型文件,如NewModel.py - 实现模型类,继承nn.Module
- 在
exp/exp_basic.py的model_dict中注册模型 - 创建对应的训练脚本,放在
scripts目录下
详细教程:tutorial/TimesNet_tutorial.ipynb
6.2 模型开发示例
以添加一个简单的LSTM模型为例:
# models/LSTM.py
import torch
import torch.nn as nn
class LSTM(nn.Module):
def __init__(self, configs):
super(LSTM, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.input_size = configs.enc_in
self.hidden_size = configs.d_model
self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True)
self.fc = nn.Linear(self.hidden_size, configs.c_out)
def forward(self, x):
# x: [B, T, D]
out, _ = self.lstm(x)
out = self.fc(out[:, -self.pred_len:, :])
return out
然后在exp/exp_basic.py中注册模型:
self.model_dict = {
# ...现有模型
'LSTM': LSTM,
}
7. 高级功能与扩展
7.1 数据增强
TSLib支持多种时间序列数据增强方法,在utils/augmentation.py中实现,包括抖动、缩放、排列等:
def jitter(x, sigma=0.05):
# 抖动增强
return x + torch.normal(0, sigma, size=x.shape).to(x.device)
def scaling(x, sigma=0.1):
# 缩放增强
factor = torch.normal(1, sigma, size=(x.shape[0], x.shape[2])).to(x.device)
return x * factor.unsqueeze(1)
数据增强源码:utils/augmentation.py
7.2 特征工程
utils/timefeatures.py提供了时间特征编码功能,支持多种时间频率:
def time_features(dates, freq='h'):
# 提取时间特征
dates['month'] = dates.date.apply(lambda row: row.month, 1)
dates['day'] = dates.date.apply(lambda row: row.day, 1)
dates['hour'] = dates.date.apply(lambda row: row.hour, 1)
# ...其他特征
return dates
时间特征源码:utils/timefeatures.py
8. 总结与展望
Time-Series-Library提供了一个统一、高效的时间序列模型开发和评估框架,支持多种先进模型和任务类型。通过本文的介绍,你已经掌握了从环境搭建到模型部署的完整流程。
未来,TSLib将继续集成更多先进模型,优化性能,并扩展到更多时间序列应用场景。欢迎通过CONTRIBUTING.md参与项目贡献。
若有任何问题,可通过项目Issues进行反馈。
本文档基于Time-Series-Library最新版本编写,建议定期查看项目更新以获取最新功能。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考




