【用pytorch进行LSTM模型的学习】

本文介绍了使用PyTorch实现LSTM模型对时间序列数据(如航班流量)进行预测的步骤,包括数据读取、观察、预处理、模型构建、训练、保存和验证。关键操作包括数据归一化、异常值处理,以及模型的保存以备后续使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

LSTM模型

LSTM模型长下面这样,主要用在时间序列的预测,具有比RNN较好的性能。原因在于内部增加了很多门,用来控制前序信息的继续、遗忘、更新等,比RNN更好的表达了特征。
在这里插入图片描述

用pytorch,采用LSTM对seaborn数据集做预测

基本步骤

一般而言,进行深度学习的训练与应用包含大概如下步骤

=========工作流程=========
 - 数据读取与基本处理
    * 数据集读取
    * 数据的观察-画图
    * 特殊数据处理-空值、奇异值等
 - 数据集构建
    * 归一化
    * 训练集、验证集、测试集划分    
 - 模型建模
    * 基础模型架构
    * 损失函数
    * 优化器选择
 - 模型训练
    * 模型训练 与各种超参
    * 训练过程观察 
    * 训练中模型保存
    * 模型训练指标记录
 - 测试验证
    * 模型性能验证
    * 结果可视化
    * 测试性能指标记录

下面就流程中的几个重点进行说明

数据的观察

在拿到数据的时候,我们首先要对数据进行观察,观察的方法根据数据的类型略有不同,但是总体可以概括为

  • 肉眼观察:打开数据文件夹或者文件进行查看,比如文件个数有多少个,数据的大小是多少。
  • 数据展示观察:对于一些不好直接观察的,可以通过数据展示看一下,如打印dataframe结构的前几行,可以看到列名等信息,方便数据处理。
  • 画图观察:对于一些时序信息,可以通过作图的方式,看看数据的分布情况,是否有异常点等等。

为什么要对数据进行观察?主要有以下几个原因

  • 获取数据的基本信息,知道我们要处理的数据大概是怎样的。
  • 对原始数据有个感觉,数据的情况可能会影响我们模型的选择。以及模型训练的策略。比如小样本数据,样本数的多少会影响下一步的决策,如是否数据增强,是否迁移等等。
  • 观察到异常情况,如空值,奇异点,为下一步数据处理做准备。

特殊数据处理

机器学习处理的是数据的一般情况,即反映数据的一般规律和一般分布,对于奇异值或者特殊值,机器学习模型没有能力处理或者需要付出很大的代价才能处理。机器学习是帮助我们解决一般问题或者共性问题,对于一些特殊的问题,并不是这个学科的主要研究方向。当然,只有一个方向除外,即异常检测。
一般需要特殊处理的,有空值、错误值、奇异值。基本的处理方式有

  • 删除,即删除特殊值
  • 补全,补全空值
  • 修正,更改错误值

数据归一化

在一般情况下,尤其是时序数据,需要进行归一化,即把数据压缩到0-1之间。目的是使得数据有相同的尺度。例如,在一个数据集中,包含样本的年龄信息,收入信息等,这两个信息的度量尺度是不同的,如果不做归一化,那么由于年龄与收入在数值上相差很大,那么年龄的特征不能在模型中发挥很好的作用。

模型的构建与选择

针对不同的任务选择不同的模型,有pytorch内置了很多基础模型,因此模型结构的构建变得简单容易,需要注意的是模型的输入参数要求以及维度匹配,这就需要我们学习pytorch内置模型的接口函数,做一个合格的调包侠

模型的保存

在训练过程中,模型是不断更新的,每一次迭代后模型的参数就会不同。在这个过程中有必要有条件地保存下当前模型,主要有如下几个用途

  • 防止训练突然崩掉,重新训练浪费资源。在较长时间的训练过程中,由于种种原因,训练可能会崩溃,如突然掉电,机器故障灯,如果没有保存训练过程中的模型,则需要重新训练,那么浪费时间,浪费资源,尤其是接近训练完成的时候发生崩溃,人就更崩溃了。如果保存了模型,那么可以重新加载模型,断点续训练。
  • 根据过程中保存下来的模型,我们可以查看模型演变过程,进行过程的考察。
  • 测试验证用,保存模型,尤其是保存最后的或者最好的模型,在测试验证时,可以直接加载进行验证,不必再次训练

那么模型该如何保存呢? 模型保存的格式:pytorch中最常见的模型保存使用 .pt 或者是 .pth 作为模型文件扩展名。

pytorch模型保存的两种方式:

  • 一种是保存整个模型,
torch.save(model, "my_model.pth") # 保存整个模型` 
  • 另一种是只保存模型的参数,该方法速度快,占用空间少
torch.save(model.state_dict(), "my_model.pth") # 只保存模型的参数

相应的,加载也有两种方式

  • 加载整个模型
new_model = torch.load(PATH) 
  • 先构架模型架构,然后加载参数
new_model = Model()                          
new_model.load_state_dict(torch.load(PATH))   

飞机航班流量预测示例

完整代码如下

# -*- coding: utf-8 -*-
# @Time    : 2023/03/10 10:23
# @Author  : HelloWorld!
# @FileName: seq.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.8

import torch
import torch.nn as nn
import argparse
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math


# 数据读取与基本处理
class LoadData:
    def __init__(self,data_path ):
        self.ori_data = pd.read_csv(data_path)
    def data_observe(self):
        self.ori_data.head()
        self.draw_data(self.ori_data)
    def draw_data(self, data):
        print(data.head())
        fig_size = plt.rcParams["figure.figsize"]
        fig_size[0] = 15
        fig_size[1] = 5
        plt.rcParams["figure.figsize"] = fig_size
        plt.title('Month vs Passenger')
        plt.ylabel(
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值