引言
LSTM应该说是每一个做机器学习的人都绕不开的东西,它的结构看起来复杂,但是充分体现着人脑在记忆过程中的特征,下面本文将介绍一下LSTM的结构以及pytorch的用法。
LSTM结构
总体结构
首先,LSTM主要用来处理带有时序信息的数据,包括视频、句子,它将人脑的对于不同time step的记忆过程理解为一连串的cell分别对不同的时刻输入信息的处理。
详细结构
一个典型的 LSTM 结构可以分别从输入、处理和输出三个角度来解析:
- 输入: 输入包含三个部分,分别是 cell 的信息𝐶t-1,它代表历史的记忆细胞(cell)状态信息的汇总;隐藏层的信息ht-1, 它是提取到的上个时刻的特征信息; 以及当前的输入𝑥t。
- 处理: 处理部分主要是由遗忘门、输入门、输出门组成。遗忘门由当前的输入和隐藏层信息控制对于历史的 cell 信息的遗忘程度;输入门是决定当前的输入和隐藏
层信息的利用程度;输出门是由当前的 cell 状态和输入决定输出。 - 输出: 分别是当前的 cell 状态𝐶’和当前的隐藏层信息h’。
遗忘门:
输入门:
细胞状态更新:
输出门:
Pytorch用法
参数介绍
class torch.nn.LSTM(*args, **kwargs)
参数:
- input_size:输入的特征维度
- hidden_size:隐藏层的特征维度(即输出的特征维度)
- num_layers:LSTM隐层的层数,默认为1
- bias:False则bih=0和bhh=0. 默认为True
- batch_first:True则输入输出的数据格式为 (batch, seq, feature)
- dropout:除最后一层,每一层的输出都进行dropout,默认为: 0
- bidirectional:True则为双向LSTM,默认为False
输入:input, (h0, c0)
输入数据格式:
input(seq_len, batch, input_size)
seq_len可以理解为一个视频有多少帧或者一个句子有多少单词,input_size就是一个帧或者一个单词可以用多少维的特征向量表示。
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)
输出:output, (hn, cn)
输出数据格式:
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)
使用实例
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)#(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10)#(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size)
c0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size)
output, (hn, cn) = rnn(input, (h0, c0))
output.shape #(seq_len, batch, output_size2)
torch.Size([5, 3, 40])
hn.shape #(num_layers2, batch, output_size)
torch.Size([2, 3, 20])
获取中间各层的隐藏层信息
lstm = nn.LSTM(3, 3)
inputs = [torch.randn(1, 3) for _ in range(5)]
# 这里的inputs的大小是一个含有5个1*3的tensor的列表,可以理解为一个5*1*3维的输入,其中5是seq_len,1是batch_size,3是input_size的大小
# 初始化隐藏状态
hidden = (torch.randn(1, 1, 3),
torch.randn(1, 1, 3))
for i in inputs:
# 将序列的元素逐个输入到LSTM,经过每步操作,hidden 的值包含了隐藏状态的信息
out, hidden = lstm(i.view(1, 1, -1), hidden)
关于变长输入
由于在视觉里变长输入的情况较少,这里只给出几个链接:
1.pytorch中如何处理RNN输入变长序列padding
https://zhuanlan.zhihu.com/p/34418001
2.教你几招搞定 LSTMs 的独门绝技
https://zhuanlan.zhihu.com/p/40391002