Pytorch学习笔记之LSTM
看了理解LSTM这篇博文,在这里写写自己对LSTM网络的一些认识!。
- RNN
- 网络计算过程
Recurrent Neural Networks
人类并不是每时每刻都从一片空白的大脑开始他们的思考。在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义。我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考。我们的思想拥有持久性。
传统的神经网络并不能做到这点,看起来也像是一种巨大的弊端。例如,假设你希望对电影中的每个时间点的时间类型进行分类。传统的神经网络应该很难来处理这个问题——使用电影中先前的事件推断后续的事件。
RNN 解决了这个问题。RNN 是包含循环的网络,允许信息的持久化
这是一个经典的RNN的流程图。
1. LSTM网络
经典的LSTM的流程图:
相信大家都看过这个图(盗用别人的图)。
再来一段公式,就是下面的,公式来自Pytorch。
h
t
h_t
ht is the hidden state at time
t
t
t ,
c
t
c_t
ct is the cell state at time
t
t
t ,
x
t
x_t
xt is the input at time
t
t
t,
h
(
t
−
1
)
h_{(t-1)}
h(t−1) is the hidden state of the previous layer at time
t
−
1
t-1
t−1 or the initial hidden state at time
0
0
0 , and
i
t
i_t
it ,
f
t
f_t
ft ,
g
t
g_t
gt ,
o
t
o_t
ot are the input, forget, cell, and output gates, respectively.
σ
\sigma
σ is the sigmoid function.
2. 内部计算分析
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
可以看到参数的大小变成了(4*20,10),是标准RNN的四倍。原因是这里它包括了四个参数矩阵 W i i W_{ii} Wii、 W i f W_{if} Wif、 W i g W_{ig} Wig、 W i o W_{io} Wio,它们每一个都是(20×10),输入的维度大小是(10×1), 这样 i t i_t it , f t f_t ft , g t g_t gt , o t o_t ot 的维度都是(20×1),公式(5)(6)的运算应该是叉积(元素积),这样得到的 c t c_t ct和 h t h_t ht的维度才能是20。
如上图所示hn和cn的最后一维都是20。注意这里的LSTM网络是单向,双向的要*2。蟹蟹!