使用Haiku框架实现LSTM时间序列预测

使用Haiku框架实现LSTM时间序列预测

dm-haiku JAX-based neural network library dm-haiku 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku

概述

本文将介绍如何使用DeepMind开发的Haiku框架构建和训练LSTM模型,用于时间序列预测任务。Haiku是一个基于JAX的神经网络库,它提供了简洁的API来构建复杂的神经网络模型。

环境准备

在开始之前,我们需要安装必要的Python包:

pip install dm-haiku optax

然后导入所需的库:

import math
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import numpy as np

数据生成

我们将使用正弦波作为我们的时间序列数据。每个正弦波有不同的相位偏移,模型的任务是根据前面的值预测下一个值。

def sine_seq(phase: float, seq_len: int, samples_per_cycle: int):
    """生成正弦波序列"""
    t = np.arange(seq_len + 1) * (2 * math.pi / samples_per_cycle)
    t = t.reshape([-1, 1]) + phase
    sine_t = np.sin(t)
    return sine_t[:-1, :], sine_t[1:, :]

我们创建了一个Dataset类来方便地批量获取数据:

class Dataset:
    """数据集迭代器,每次返回一个批次的数据"""
    
    def __init__(self, xy: tuple, batch_size: int):
        self._x, self._y = xy
        self._batch_size = batch_size
        self._length = self._x.shape[1]
        self._idx = 0
        
    def __next__(self):
        """获取下一个批次的数据"""
        start = self._idx
        end = start + self._batch_size
        x, y = self._x[:, start:end], self._y[:, start:end]
        if end >= self._length:
            end = end % self._length
        self._idx = end
        return x, y

LSTM模型构建

使用Haiku构建LSTM模型非常简单。我们定义一个函数来展开LSTM网络:

def unroll_net(seqs: jax.Array):
    """展开LSTM网络处理序列数据"""
    core = hk.LSTM(32)  # 32个隐藏单元
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    return hk.BatchApply(hk.Linear(1))(outs), state  # 输出层

使用hk.transform将我们的函数转换为纯函数:

model = hk.transform(unroll_net)

训练过程

我们使用Adam优化器和均方误差损失函数来训练模型:

def train_model(train_ds: Dataset, valid_ds: Dataset) -> hk.Params:
    """训练模型并返回最终参数"""
    rng = jax.random.PRNGKey(428)
    opt = optax.adam(1e-3)  # 学习率0.001
    
    @jax.jit
    def loss(params, x, y):
        """计算均方误差损失"""
        pred, _ = model.apply(params, None, x)
        return jnp.mean(jnp.square(pred - y))
    
    @jax.jit
    def update(step, params, opt_state, x, y):
        """单步参数更新"""
        l, grads = jax.value_and_grad(loss)(params, x, y)
        grads, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, grads)
        return l, params, opt_state
    
    # 初始化参数
    sample_x, _ = next(train_ds)
    params = model.init(rng, sample_x)
    opt_state = opt.init(params)
    
    # 训练循环
    for step in range(2001):
        x, y = next(train_ds)
        train_loss, params, opt_state = update(step, params, opt_state, x, y)
        
        if step % 100 == 0:
            x, y = next(valid_ds)
            print(f"Step {step}: valid loss {loss(params, x, y)}")
    
    return params

模型预测

训练完成后,我们可以使用模型进行预测。有两种预测方式:

  1. 使用真实值作为输入:在每一步都使用真实值作为输入
  2. 自回归预测:使用模型自己的预测作为下一步的输入

使用真实值预测

# 获取验证集样本
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1]  # 取batch_size=1

# 使用真实值作为输入进行预测
predicted, _ = model.apply(trained_params, None, sample_x)

自回归预测

自回归预测更接近实际应用场景,但计算效率较低:

def autoregressive_predict(trained_params, context, seq_len):
    """自回归预测"""
    ar_outs = []
    context = jax.device_put(context)
    for _ in range(seq_len - context.shape[0]):
        full_context = jnp.concatenate([context] + ar_outs)
        outs, _ = jax.jit(model.apply)(trained_params, None, full_context)
        ar_outs.append(outs[-1:])
    return outs

为了提高效率,我们可以专门为自回归预测设计一个更快的Haiku函数:

def fast_autoregressive_predict_fn(context, seq_len):
    """高效的自回归预测函数"""
    core = hk.LSTM(32)
    dense = hk.Linear(1)
    state = core.initial_state(context.shape[1])
    
    # 使用真实上下文初始化状态
    context_outs, state = hk.dynamic_unroll(core, context, state)
    context_outs = hk.BatchApply(dense)(context_outs)
    
    # 自回归预测
    ar_outs = []
    x = context_outs[-1]
    for _ in range(seq_len - context.shape[0]):
        x, state = core(x, state)
        x = dense(x)
        ar_outs.append(x)
    return jnp.concatenate([context_outs, jnp.stack(ar_outs)])

性能比较

两种自回归预测方法在性能上有显著差异:

%timeit autoregressive_predict(trained_params, context, SEQ_LEN)
%timeit fast_ar_predict(trained_params, None, context, SEQ_LEN)

高效版本通常比原始版本快一个数量级以上。

总结

本文展示了如何使用Haiku框架构建和训练LSTM模型进行时间序列预测。关键点包括:

  1. Haiku提供了简洁的API来构建复杂的神经网络
  2. 使用hk.transform可以将模型转换为纯函数
  3. 自回归预测可以模拟实际应用场景
  4. 专门优化的预测函数可以显著提高性能

Haiku与JAX的结合为深度学习研究提供了强大的工具,特别是在需要高性能计算的场景下。通过本教程,读者可以掌握使用Haiku构建循环神经网络的基本方法。

dm-haiku JAX-based neural network library dm-haiku 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

俞淑瑜Sally

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值