PyTorch深度学习教程:RNN与LSTM架构详解
概述:序列数据处理与RNN基础
循环神经网络(RNN)是处理序列数据的重要架构。序列数据本质上是一维时间轴上的数据流,但也可以扩展到二维空间(如文本处理中的双向序列)。与传统的前馈神经网络(Vanilla NN)相比,RNN的关键特性在于其具有记忆功能。
传统神经网络 vs 循环神经网络
传统神经网络(图1)是纯粹的前馈结构,当前输出仅取决于当前输入,类似于数字电路中的组合逻辑。而RNN(图2)的输出不仅取决于当前输入,还依赖于系统的历史状态,这类似于数字电路中的时序逻辑,使用"触发器"来保持状态。
图1:传统神经网络架构
图2:RNN架构
在Yann LeCun提出的表示法中(图3-4),神经元之间的连接形状代表了张量之间的映射关系,即通过仿射变换(旋转加扭曲)将输入向量转换为隐藏表示,再进一步转换为输出。
RNN的四种架构类型及应用实例
1. 向量到序列(Vec2Seq)
输入为单个向量(如图像),输出为符号序列(如描述语句)。典型应用是图像描述生成(图5-6)。这类自回归网络的特点是:当前输出作为下一时间步的输入。
图5:向量到序列架构
图6:图像描述生成实例
2. 序列到向量(Seq2Vec)
连续输入符号序列,最终输出单个向量。应用包括Python代码解释器(图7-9),网络能理解并执行输入的代码序列。
图7:序列到向量架构
3. 序列到向量到序列(Seq2Vec2Seq)
曾用于机器翻译的标准架构(图10)。先将输入序列压缩为语义向量,再解码为目标语言。通过PCA分析潜在空间,可发现语义相关的词汇会自然聚类(图11-13)。
图10:序列到向量到序列架构
4. 序列到序列(Seq2Seq)
实时输入输出架构(图15),如手机输入预测(T9)和自动文本补全(图16)。训练于科幻小说的RNN写作助手能根据开头生成连贯后续。
图15:序列到序列架构
随时间反向传播(BPTT)
模型架构与训练
RNN训练需使用BPTT算法。图17展示了RNN的两种表示:循环形式(左)和按时间展开形式(右)。隐藏状态计算为:
$$ h[t] = g(W_h\begin{bmatrix}x[t]\h[t-1]\end{bmatrix} + b_h) $$
图17:随时间反向传播
语言模型中的批处理
处理符号序列时,可将文本批量化为不同尺寸(图18)。设置BPTT周期T后,需在最后一步切断梯度传播(使用.detach()),防止无限传播(图19)。
图18:批处理示例
梯度消失与爆炸问题
问题分析
RNN中的矩阵变换可能导致梯度随时间膨胀(爆炸)或收缩(消失)。如图20所示,初始较大的梯度(亮色)经过几次旋转后可能完全消失。
图20:梯度消失问题
解决方案:跳跃连接
通过门控机制(如LSTM)实现跳跃连接(图21),将网络分割为多个子网络,选择性控制梯度传播路径。
图21:跳跃连接解决方案
长短期记忆网络(LSTM)
架构详解
LSTM通过三个门控单元(输入门、遗忘门、输出门)精细控制信息流(图22):
- 输入门(黄色):调节新信息的写入
- 遗忘门:控制历史记忆的保留
- 输出门:控制信息的输出
数学表示为: $$ \begin{aligned} i[t] &= \sigma(W_i[h[t-1],x[t]]+b_i) \ f[t] &= \sigma(W_f[h[t-1],x[t]]+b_f) \ o[t] &= \sigma(W_o[h[t-1],x[t]]+b_o) \ \tilde{c}[t] &= \tanh(W_c[h[t-1],x[t]]+b_c) \ c[t] &= f[t]\odot c[t-1] + i[t]\odot \tilde{c}[t] \ h[t] &= o[t]\odot \tanh(c[t]) \end{aligned} $$
图22:LSTM核心架构
门控机制可视化
通过设置不同的门控状态(0/1),LSTM可以实现:
- 重置记忆(图26):遗忘门=0,输入门=0
- 保持记忆(图27):遗忘门=1,输入门=0
- 写入记忆(图28):遗忘门=0,输入门=1
图25:记忆单元可视化
实战应用:序列分类
在PyTorch中实现序列分类任务时,通常将输入元素和目标表示为局部向量。通过LSTM处理变长序列,最后使用全连接层进行分类。关键步骤包括:
- 嵌入层:将离散符号转换为连续向量
- LSTM层:处理变长序列
- 全连接层:输出分类结果
- 使用交叉熵损失进行训练
这种架构可广泛应用于文本分类、情感分析等任务,展现了RNN/LSTM处理序列数据的强大能力。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考