背景介绍
时间序列的处理使用RNN更为有效。但RNN中的一些参数理解起来与CNN差别很大,这篇文章主要梳理一下RNN中LSTM架构的几个关键参数以及如何理解这些参数。
以pytorch为例,我们首先看一下LSTM网络的构建过程
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.LSTM(
input_size=1,
hidden_size=64,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(64, 2)
def forward(self, x):
r_out, (h_n, h_c) = self.rnn(x, None)
out = self.out(h_n[0])
return out
我们使用一个比较简单理解的例子来解释一下这几个主要参数的含义,比如我们用30天的买东西的数据来预测第31天的,每天采集一组数据,这组数据可以表示为
day1 : {面包:5个,泡面3个,火腿肠2个,卤蛋2个,可乐2个}
day2 : {面包:3个,泡面1个,火腿肠2个,卤蛋1个,可乐1个}
以此类推
这里我们可以看到