PyTorch深度学习教程:循环神经网络与LSTM架构详解
概述
循环神经网络(Recurrent Neural Networks, RNN)是处理序列数据的强大工具。与传统的神经网络不同,RNN能够捕捉数据中的时间依赖性,这使得它在自然语言处理、语音识别、时间序列预测等领域表现出色。
Vanilla神经网络 vs 循环神经网络
Vanilla神经网络
Vanilla神经网络(普通前馈神经网络)的结构特点是:
- 输入层、隐藏层和输出层组成
- 每一时刻的输出仅取决于当前输入
- 类似于数字电路中的组合逻辑电路
graph LR
A[输入x] --> B[隐藏层]
B --> C[输出y]
循环神经网络
RNN的独特之处在于:
- 引入隐藏状态(hidden state)作为记忆单元
- 当前输出不仅取决于当前输入,还取决于前一时刻的隐藏状态
- 类似于数字电路中的时序逻辑电路
graph LR
A[输入x_t] --> B[隐藏层h_t]
B --> C[输出y_t]
B --> D[隐藏状态h_t+1]
D --> B
RNN的四种典型架构
1. 向量到序列(Vec2Seq)
应用场景:图像描述生成
- 输入:单个图像向量
- 输出:描述图像的单词序列
- 特点:自回归网络,前一输出作为下一输入
示例: 输入图像 → 输出:"一只黄色校车停在停车场"
2. 序列到向量(Seq2Vec)
应用场景:程序执行结果预测
- 输入:Python代码序列
- 输出:程序执行结果
- 特点:整个序列处理后产生单一输出
示例:
x = 3
y = 5
print(x + y)
→ 输出:8
3. 序列到向量到序列(Seq2Vec2Seq)
应用场景:机器翻译
- 编码器:将源语言序列压缩为语义向量
- 解码器:将语义向量展开为目标语言序列
- 特点:曾是最先进的翻译模型架构
语义空间特性:
- 相似语义的词在潜在空间中聚集
- 可进行"king - man + woman = queen"等向量运算
4. 序列到序列(Seq2Seq)
应用场景:实时文本预测
- 输入:部分文本序列
- 输出:后续文本预测
- 特点:输入输出同步进行
示例: 输入:"The rings of Saturn glittered while" → 输出:"two men looked at each other"
基于时间的反向传播(BPTT)
算法原理
- 将RNN按时间步展开
- 前向计算各时间步的输出和损失
- 反向传播时考虑时间维度上的梯度流动
- 梯度通过时间步反向传播
数学表达
隐藏状态计算: $$ h[t] = g(W_{hx}x[t] + W_{hh}h[t-1] + b_h) $$
输出计算: $$ \hat{y}[t] = g(W_yh[t] + b_y) $$
分批处理技巧
- 将文本序列划分为批次
- 垂直方向保持时间连续性
- 使用
.detach()
防止梯度无限传播
示例批次:
a g m s
b h n t
c i o u
梯度问题与解决方案
梯度消失与爆炸
原因:
- 重复的矩阵乘法导致梯度指数级变化
- 特征值>1 → 梯度爆炸
- 特征值<1 → 梯度消失
影响:
- 难以学习长期依赖关系
- 训练过程不稳定
解决方案:LSTM架构
长短期记忆网络(LSTM)通过引入:
- 输入门:控制新信息的流入
- 遗忘门:决定保留多少旧记忆
- 输出门:控制输出信息
- 细胞状态:长期记忆的传递路径
graph LR
C[t-1] --> F[遗忘门]
C[t-1] --> I[输入门]
I --> C[t]
F --> C[t]
C[t] --> O[输出门]
O --> h[t]
其他改进方案
- 梯度裁剪:防止梯度爆炸
- 残差连接:缓解梯度消失
- 层归一化:稳定训练过程
总结
RNN及其变体LSTM是处理序列数据的强大工具。理解其架构和工作原理对于在实际任务中应用和调优这些模型至关重要。虽然Transformer等新架构在某些任务上表现更优,RNN/LSTM仍然是深度学习工具包中的重要组成部分,特别适合处理中等长度的序列数据。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考