深入理解LSTM长短期记忆网络(基于d2l-ai项目)

深入理解LSTM长短期记忆网络(基于d2l-ai项目)

d2l-en d2l-ai/d2l-en: 是一个基于 Python 的深度学习教程,它使用了 SQLite 数据库存储数据。适合用于学习深度学习,特别是对于需要使用 Python 和 SQLite 数据库的场景。特点是深度学习教程、Python、SQLite 数据库。 d2l-en 项目地址: https://gitcode.com/gh_mirrors/d2/d2l-en

引言

循环神经网络(RNN)在处理序列数据时表现出色,但在处理长序列时面临着梯度消失和梯度爆炸的挑战。1997年,Hochreiter和Schmidhuber提出了长短期记忆网络(LSTM),这一创新架构有效解决了这些问题,成为深度学习领域的重要里程碑。

LSTM的核心思想

传统RNN的局限性

传统RNN存在两个主要问题:

  1. 长期依赖问题:当序列较长时,梯度在反向传播过程中会逐渐消失或爆炸
  2. 固定权重更新:所有时间步共享相同的权重更新机制,缺乏灵活性

LSTM的创新设计

LSTM通过引入三个关键机制解决了这些问题:

  1. 记忆单元:维护一个内部状态,可以长时间保留信息
  2. 门控机制:控制信息的流动,决定保留或丢弃哪些信息
  3. 自循环连接:固定权重为1的连接,确保梯度可以稳定传播

LSTM的详细架构

门控机制

LSTM包含三种不同类型的门:

  1. 输入门(Input Gate):决定新输入信息对内部状态的影响程度
  2. 遗忘门(Forget Gate):决定保留或丢弃多少先前的记忆
  3. 输出门(Output Gate):决定内部状态对当前输出的影响程度

数学表达

对于隐藏单元数h,批量大小n,输入维度d:

  • 输入门:$\mathbf{I}t = \sigma(\mathbf{X}t \mathbf{W}{\textrm{xi}} + \mathbf{H}{t-1} \mathbf{W}{\textrm{hi}} + \mathbf{b}\textrm{i})$
  • 遗忘门:$\mathbf{F}t = \sigma(\mathbf{X}t \mathbf{W}{\textrm{xf}} + \mathbf{H}{t-1} \mathbf{W}{\textrm{hf}} + \mathbf{b}\textrm{f})$
  • 输出门:$\mathbf{O}t = \sigma(\mathbf{X}t \mathbf{W}{\textrm{xo}} + \mathbf{H}{t-1} \mathbf{W}{\textrm{ho}} + \mathbf{b}\textrm{o})$

记忆单元更新

记忆单元的内部状态更新公式为: $$\mathbf{C}_t = \mathbf{F}t \odot \mathbf{C}{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t$$

其中$\tilde{\mathbf{C}}_t$是候选记忆状态,使用tanh激活函数计算。

从零实现LSTM

初始化参数

我们需要初始化以下参数:

  1. 输入门、遗忘门、输出门的权重和偏置
  2. 候选记忆状态的权重和偏置
class LSTMScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        
        # 初始化各门的权重和偏置
        self.W_xi, self.W_hi, self.b_i = self._init_triple()  # 输入门
        self.W_xf, self.W_hf, self.b_f = self._init_triple()  # 遗忘门
        self.W_xo, self.W_ho, self.b_o = self._init_triple()  # 输出门
        self.W_xc, self.W_hc, self.b_c = self._init_triple()  # 候选记忆

前向传播

实现LSTM的前向传播逻辑:

def forward(self, inputs, H_C=None):
    if H_C is None:
        # 初始化隐藏状态和记忆状态
        H = d2l.zeros((inputs.shape[1], self.num_hiddens))
        C = d2l.zeros((inputs.shape[1], self.num_hiddens))
    else:
        H, C = H_C
    
    outputs = []
    for X in inputs:
        # 计算各门的值
        I = d2l.sigmoid(d2l.matmul(X, self.W_xi) + d2l.matmul(H, self.W_hi) + self.b_i)
        F = d2l.sigmoid(d2l.matmul(X, self.W_xf) + d2l.matmul(H, self.W_hf) + self.b_f)
        O = d2l.sigmoid(d2l.matmul(X, self.W_xo) + d2l.matmul(H, self.W_ho) + self.b_o)
        
        # 计算候选记忆状态
        C_tilde = d2l.tanh(d2l.matmul(X, self.W_xc) + d2l.matmul(H, self.W_hc) + self.b_c)
        
        # 更新记忆状态和隐藏状态
        C = F * C + I * C_tilde
        H = O * d2l.tanh(C)
        outputs.append(H)
    
    return outputs, (H, C)

使用高级API实现

现代深度学习框架都提供了LSTM的高级实现,可以简化代码并提高效率:

class LSTM(d2l.RNN):
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = nn.LSTM(num_hiddens)  # 使用框架内置LSTM

    def forward(self, inputs, H_C=None):
        return self.rnn(inputs, H_C)

训练与评估

我们可以使用时间机器数据集来训练LSTM模型:

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)

总结

LSTM通过精心设计的门控机制,有效解决了传统RNN在处理长序列时的梯度问题。其核心创新在于:

  1. 引入记忆单元长期保存信息
  2. 使用门控机制灵活控制信息流
  3. 通过自循环连接稳定梯度传播

理解LSTM的工作原理对于掌握现代序列建模技术至关重要,它是许多先进模型(如GRU、Transformer)的基础。通过从零实现和高级API的对比学习,我们可以更深入地理解这一重要架构。

d2l-en d2l-ai/d2l-en: 是一个基于 Python 的深度学习教程,它使用了 SQLite 数据库存储数据。适合用于学习深度学习,特别是对于需要使用 Python 和 SQLite 数据库的场景。特点是深度学习教程、Python、SQLite 数据库。 d2l-en 项目地址: https://gitcode.com/gh_mirrors/d2/d2l-en

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

施笛娉Tabitha

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值