Pytorch 小记 第十一回:RNN循环神经网络模型代码
本次小记,提供了一份基于pytorch的RNN循环神经网络模型的代码。代码是基于RNN模型来完成对股票收盘价格的预测。除此之外,对代码中不容易理解的部分进行了讲解。
本代码的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0,d2l的版本是1.0.3
文章目录
一、程序代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
import time
data_csv = pd.read_csv('./stock_data.csv', usecols=[3])
# 数据预处理
data_csv = data_csv.dropna()
dataset = data_csv.values
dataset = dataset.astype('float32')
max_value = np.max(dataset)
min_value = np.min(dataset)
interval = max_value - min_value
dataset = list(map(lambda x: x / interval, dataset))
# 数据划分
def optimize_dataset(dataset, primitive=3):
dataX, dataY = [], []
for i in range(len(dataset) - primitive):
a = dataset[i:(i + primitive)]
dataX.append(a)
dataY.append(dataset[i + primitive])
return np.array(dataX), np.array(dataY)
data_X, data_Y = optimize_dataset(dataset)
# 划分训练集和测试集
train_size = int