深入理解LSTM长短期记忆网络(基于d2l-ai项目)
引言
循环神经网络(RNN)在处理序列数据时表现出色,但在处理长序列时面临着梯度消失和梯度爆炸的挑战。1997年,Hochreiter和Schmidhuber提出了长短期记忆网络(LSTM),这一创新架构有效解决了这些问题,成为深度学习领域的重要里程碑。
LSTM的核心思想
传统RNN的局限性
传统RNN存在两个主要问题:
- 长期依赖问题:当序列较长时,梯度在反向传播过程中会逐渐消失或爆炸
- 固定权重更新:所有时间步共享相同的权重更新机制,缺乏灵活性
LSTM的创新设计
LSTM通过引入三个关键机制解决了这些问题:
- 记忆单元:维护一个内部状态,可以长时间保留信息
- 门控机制:控制信息的流动,决定保留或丢弃哪些信息
- 自循环连接:固定权重为1的连接,确保梯度可以稳定传播
LSTM的详细架构
门控机制
LSTM包含三种不同类型的门:
- 输入门(Input Gate):决定新输入信息对内部状态的影响程度
- 遗忘门(Forget Gate):决定保留或丢弃多少先前的记忆
- 输出门(Output Gate):决定内部状态对当前输出的影响程度
数学表达
对于隐藏单元数h,批量大小n,输入维度d:
- 输入门:$\mathbf{I}t = \sigma(\mathbf{X}t \mathbf{W}{\textrm{xi}} + \mathbf{H}{t-1} \mathbf{W}{\textrm{hi}} + \mathbf{b}\textrm{i})$
- 遗忘门:$\mathbf{F}t = \sigma(\mathbf{X}t \mathbf{W}{\textrm{xf}} + \mathbf{H}{t-1} \mathbf{W}{\textrm{hf}} + \mathbf{b}\textrm{f})$
- 输出门:$\mathbf{O}t = \sigma(\mathbf{X}t \mathbf{W}{\textrm{xo}} + \mathbf{H}{t-1} \mathbf{W}{\textrm{ho}} + \mathbf{b}\textrm{o})$
记忆单元更新
记忆单元的内部状态更新公式为: $$\mathbf{C}_t = \mathbf{F}t \odot \mathbf{C}{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t$$
其中$\tilde{\mathbf{C}}_t$是候选记忆状态,使用tanh激活函数计算。
从零实现LSTM
初始化参数
我们需要初始化以下参数:
- 输入门、遗忘门、输出门的权重和偏置
- 候选记忆状态的权重和偏置
class LSTMScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
# 初始化各门的权重和偏置
self.W_xi, self.W_hi, self.b_i = self._init_triple() # 输入门
self.W_xf, self.W_hf, self.b_f = self._init_triple() # 遗忘门
self.W_xo, self.W_ho, self.b_o = self._init_triple() # 输出门
self.W_xc, self.W_hc, self.b_c = self._init_triple() # 候选记忆
前向传播
实现LSTM的前向传播逻辑:
def forward(self, inputs, H_C=None):
if H_C is None:
# 初始化隐藏状态和记忆状态
H = d2l.zeros((inputs.shape[1], self.num_hiddens))
C = d2l.zeros((inputs.shape[1], self.num_hiddens))
else:
H, C = H_C
outputs = []
for X in inputs:
# 计算各门的值
I = d2l.sigmoid(d2l.matmul(X, self.W_xi) + d2l.matmul(H, self.W_hi) + self.b_i)
F = d2l.sigmoid(d2l.matmul(X, self.W_xf) + d2l.matmul(H, self.W_hf) + self.b_f)
O = d2l.sigmoid(d2l.matmul(X, self.W_xo) + d2l.matmul(H, self.W_ho) + self.b_o)
# 计算候选记忆状态
C_tilde = d2l.tanh(d2l.matmul(X, self.W_xc) + d2l.matmul(H, self.W_hc) + self.b_c)
# 更新记忆状态和隐藏状态
C = F * C + I * C_tilde
H = O * d2l.tanh(C)
outputs.append(H)
return outputs, (H, C)
使用高级API实现
现代深度学习框架都提供了LSTM的高级实现,可以简化代码并提高效率:
class LSTM(d2l.RNN):
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = nn.LSTM(num_hiddens) # 使用框架内置LSTM
def forward(self, inputs, H_C=None):
return self.rnn(inputs, H_C)
训练与评估
我们可以使用时间机器数据集来训练LSTM模型:
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)
总结
LSTM通过精心设计的门控机制,有效解决了传统RNN在处理长序列时的梯度问题。其核心创新在于:
- 引入记忆单元长期保存信息
- 使用门控机制灵活控制信息流
- 通过自循环连接稳定梯度传播
理解LSTM的工作原理对于掌握现代序列建模技术至关重要,它是许多先进模型(如GRU、Transformer)的基础。通过从零实现和高级API的对比学习,我们可以更深入地理解这一重要架构。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考