LSTM的记忆能力实验
飞桨AI Studio星河社区-人工智能学习与实训社区 (baidu.com)
长短期记忆网络(Long Short-Term Memory Network,LSTM)是一种可以有效缓解长程依赖问题的循环神经网络.LSTM 的特点是引入了一个新的内部状态(Internal State) 和门控机制(Gating Mechanism).不同时刻的内部状态以近似线性的方式进行传递,从而缓解梯度消失或梯度爆炸问题.同时门控机制进行信息筛选,可以有效地增加记忆能力.例如,输入门可以让网络忽略无关紧要的输入信息,遗忘门可以使得网络保留有用的历史信息.在上一节的数字求和任务中,如果模型能够记住前两个非零数字,同时忽略掉一些不重要的干扰信息,那么即时序列很长,模型也有效地进行预测.
LSTM 模型在第 t 步时,循环单元的内部结构如图所示.
提醒:为了和代码的实现保存一致性,这里使用形状为 (样本数量 × 序列长度 × 特征维度) 的张量来表示一组样本.
假设一组输入序列为,,其中B为批大小,L为序列长度,M为输入特征维度,LSTM从从左到右依次扫描序列,并通过循环单元计算更新每一时刻的状态内部状态
,和输出状态
。
具体计算分为三步:
(1)计算三个“门”
在时刻t,LSTM的循环单元将当前时刻的输入,,与上一时刻的输出状态
,计算一组输入门
、遗忘门
和输出门
,其计算公式为
其中为可学习的参数,σ表示Logistic函数,将“门”的取值控制在(0,1)区间。这里的“门”都是B个样本组成的矩阵,每一行为一个样本的“门”向量。
(2)计算内部状态
首先计算候选内部状态:
其中 为可学习的参数。
使用遗忘门和输入门,计算时刻t的内部状态:
其中⊙为逐元素积。
(3)计算输出状态
当前LSTM单元状态(候选状态)的计算公式为: LSTM单元状态向量Ct和Ht的计算公式为
LSTM循环单元结构的输入是t−1时刻内部状态向量,和隐状态向量
,输出是当前时刻t的状态向量
,通过LSTM循环单元,整个网络可以建立较长距离的时序依赖关系。
通过学习这些门的设置,LSTM可以选择性地忽略或者强化当前的记忆或是输入信息,帮助网络更好地学习长句子的语义信息。
模型构建
在本实验中,我们将使用上个实验中定义Model_RNN4SeqClass模型,并构建 LSTM 算子.只需要实例化 LSTM 算,并传入Model_RNN4SeqClass模型,就可以用 LSTM 进行数字求和实验
LSTM层
LSTM层的代码与SRN层结构相似,只是在SRN层的基础上增加了内部状态、输入门、遗忘门和输出门的定义和计算。这里LSTM层的输出也依然为序列的最后一个位置的隐状态向量。代码实现如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, Wi_attr=None, Wf_attr=None, Wo_attr=None, Wc_attr=None,
Ui_attr=None, Uf_attr=None, Uo_attr=None, Uc_attr=None, bi_attr=None, bf_attr=None,
bo_attr=None, bc_attr=None):
super(LSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# 初始化模型参数
if Wi_attr is None:
Wi = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wi = torch.tensor(Wi_attr, dtype=torch.float32)
self.W_i = torch.nn.Parameter(Wi)
if Wf_attr is None:
Wf = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wf = torch.tensor(Wf_attr, dtype=torch.float32)
self.W_f = torch.nn.Parameter(Wf)
if Wo_attr is None:
Wo = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wo = torch.tensor(Wo_attr, dtype=torch.float32)
self.W_o = torch.nn.Parameter(Wo)
if Wc_attr is None:
Wc = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wc = torch.tensor(Wc_attr, dtype=torch.float32)
self.W_c = torch.nn.Parameter(Wc)
if Ui_attr is None:
Ui = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Ui = torch.tensor(Ui_attr, dtype=torch.float32)
self.U_i = torch.nn.Parameter(Ui)
if Uf_attr is None:
Uf = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Uf = torch.tensor(Uf_attr, dtype=torch.float32)
self.U_f = torch.nn.Parameter(Uf)
if Uo_attr is None:
Uo = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Uo = torch.tensor(Uo_attr, dtype=torch.float32)
self.U_o = torch.nn.Parameter(Uo)
if Uc_attr is None:
Uc = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Uc = torch.tensor(Uc_attr, dtype=torch.float32)
self.U_c = torch.nn.Parameter(Uc)
if bi_attr is None:
bi = torch.zeros(size=[1, hidden_size], dtype=torch.float32)
else:
bi = torch.tensor(bi_attr, dtype=torch.float32)
self.b_i = torch.nn.Parameter(bi)
if bf_attr is None:
bf = torch.zeros(size=[1, hidde