使用Haiku框架实现LSTM时间序列预测
dm-haiku JAX-based neural network library 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku
概述
本文将介绍如何使用DeepMind开发的Haiku框架构建和训练LSTM模型,用于时间序列预测任务。Haiku是一个基于JAX的神经网络库,它提供了简洁的API来构建复杂的神经网络模型。
环境准备
在开始之前,我们需要安装必要的Python包:
pip install dm-haiku optax
然后导入所需的库:
import math
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import numpy as np
数据生成
我们将使用正弦波作为我们的时间序列数据。每个正弦波有不同的相位偏移,模型的任务是根据前面的值预测下一个值。
def sine_seq(phase: float, seq_len: int, samples_per_cycle: int):
"""生成正弦波序列"""
t = np.arange(seq_len + 1) * (2 * math.pi / samples_per_cycle)
t = t.reshape([-1, 1]) + phase
sine_t = np.sin(t)
return sine_t[:-1, :], sine_t[1:, :]
我们创建了一个Dataset类来方便地批量获取数据:
class Dataset:
"""数据集迭代器,每次返回一个批次的数据"""
def __init__(self, xy: tuple, batch_size: int):
self._x, self._y = xy
self._batch_size = batch_size
self._length = self._x.shape[1]
self._idx = 0
def __next__(self):
"""获取下一个批次的数据"""
start = self._idx
end = start + self._batch_size
x, y = self._x[:, start:end], self._y[:, start:end]
if end >= self._length:
end = end % self._length
self._idx = end
return x, y
LSTM模型构建
使用Haiku构建LSTM模型非常简单。我们定义一个函数来展开LSTM网络:
def unroll_net(seqs: jax.Array):
"""展开LSTM网络处理序列数据"""
core = hk.LSTM(32) # 32个隐藏单元
batch_size = seqs.shape[1]
outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
return hk.BatchApply(hk.Linear(1))(outs), state # 输出层
使用hk.transform
将我们的函数转换为纯函数:
model = hk.transform(unroll_net)
训练过程
我们使用Adam优化器和均方误差损失函数来训练模型:
def train_model(train_ds: Dataset, valid_ds: Dataset) -> hk.Params:
"""训练模型并返回最终参数"""
rng = jax.random.PRNGKey(428)
opt = optax.adam(1e-3) # 学习率0.001
@jax.jit
def loss(params, x, y):
"""计算均方误差损失"""
pred, _ = model.apply(params, None, x)
return jnp.mean(jnp.square(pred - y))
@jax.jit
def update(step, params, opt_state, x, y):
"""单步参数更新"""
l, grads = jax.value_and_grad(loss)(params, x, y)
grads, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, grads)
return l, params, opt_state
# 初始化参数
sample_x, _ = next(train_ds)
params = model.init(rng, sample_x)
opt_state = opt.init(params)
# 训练循环
for step in range(2001):
x, y = next(train_ds)
train_loss, params, opt_state = update(step, params, opt_state, x, y)
if step % 100 == 0:
x, y = next(valid_ds)
print(f"Step {step}: valid loss {loss(params, x, y)}")
return params
模型预测
训练完成后,我们可以使用模型进行预测。有两种预测方式:
- 使用真实值作为输入:在每一步都使用真实值作为输入
- 自回归预测:使用模型自己的预测作为下一步的输入
使用真实值预测
# 获取验证集样本
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1] # 取batch_size=1
# 使用真实值作为输入进行预测
predicted, _ = model.apply(trained_params, None, sample_x)
自回归预测
自回归预测更接近实际应用场景,但计算效率较低:
def autoregressive_predict(trained_params, context, seq_len):
"""自回归预测"""
ar_outs = []
context = jax.device_put(context)
for _ in range(seq_len - context.shape[0]):
full_context = jnp.concatenate([context] + ar_outs)
outs, _ = jax.jit(model.apply)(trained_params, None, full_context)
ar_outs.append(outs[-1:])
return outs
为了提高效率,我们可以专门为自回归预测设计一个更快的Haiku函数:
def fast_autoregressive_predict_fn(context, seq_len):
"""高效的自回归预测函数"""
core = hk.LSTM(32)
dense = hk.Linear(1)
state = core.initial_state(context.shape[1])
# 使用真实上下文初始化状态
context_outs, state = hk.dynamic_unroll(core, context, state)
context_outs = hk.BatchApply(dense)(context_outs)
# 自回归预测
ar_outs = []
x = context_outs[-1]
for _ in range(seq_len - context.shape[0]):
x, state = core(x, state)
x = dense(x)
ar_outs.append(x)
return jnp.concatenate([context_outs, jnp.stack(ar_outs)])
性能比较
两种自回归预测方法在性能上有显著差异:
%timeit autoregressive_predict(trained_params, context, SEQ_LEN)
%timeit fast_ar_predict(trained_params, None, context, SEQ_LEN)
高效版本通常比原始版本快一个数量级以上。
总结
本文展示了如何使用Haiku框架构建和训练LSTM模型进行时间序列预测。关键点包括:
- Haiku提供了简洁的API来构建复杂的神经网络
- 使用
hk.transform
可以将模型转换为纯函数 - 自回归预测可以模拟实际应用场景
- 专门优化的预测函数可以显著提高性能
Haiku与JAX的结合为深度学习研究提供了强大的工具,特别是在需要高性能计算的场景下。通过本教程,读者可以掌握使用Haiku构建循环神经网络的基本方法。
dm-haiku JAX-based neural network library 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考