PyTorch深度学习教程:RNN与LSTM架构详解

PyTorch深度学习教程:RNN与LSTM架构详解

NYU-DLSP20 NYU Deep Learning Spring 2020 NYU-DLSP20 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-Deep-Learning

概述:序列数据处理与RNN基础

循环神经网络(RNN)是处理序列数据的重要架构。序列数据本质上是一维时间轴上的数据流,但也可以扩展到二维空间(如文本处理中的双向序列)。与传统的前馈神经网络(Vanilla NN)相比,RNN的关键特性在于其具有记忆功能。

传统神经网络 vs 循环神经网络

传统神经网络(图1)是纯粹的前馈结构,当前输出仅取决于当前输入,类似于数字电路中的组合逻辑。而RNN(图2)的输出不仅取决于当前输入,还依赖于系统的历史状态,这类似于数字电路中的时序逻辑,使用"触发器"来保持状态。

传统神经网络架构 图1:传统神经网络架构

RNN架构 图2:RNN架构

在Yann LeCun提出的表示法中(图3-4),神经元之间的连接形状代表了张量之间的映射关系,即通过仿射变换(旋转加扭曲)将输入向量转换为隐藏表示,再进一步转换为输出。

RNN的四种架构类型及应用实例

1. 向量到序列(Vec2Seq)

输入为单个向量(如图像),输出为符号序列(如描述语句)。典型应用是图像描述生成(图5-6)。这类自回归网络的特点是:当前输出作为下一时间步的输入。

Vec2Seq架构 图5:向量到序列架构

图像描述生成示例 图6:图像描述生成实例

2. 序列到向量(Seq2Vec)

连续输入符号序列,最终输出单个向量。应用包括Python代码解释器(图7-9),网络能理解并执行输入的代码序列。

Seq2Vec架构 图7:序列到向量架构

3. 序列到向量到序列(Seq2Vec2Seq)

曾用于机器翻译的标准架构(图10)。先将输入序列压缩为语义向量,再解码为目标语言。通过PCA分析潜在空间,可发现语义相关的词汇会自然聚类(图11-13)。

Seq2Vec2Seq架构 图10:序列到向量到序列架构

4. 序列到序列(Seq2Seq)

实时输入输出架构(图15),如手机输入预测(T9)和自动文本补全(图16)。训练于科幻小说的RNN写作助手能根据开头生成连贯后续。

Seq2Seq架构 图15:序列到序列架构

随时间反向传播(BPTT)

模型架构与训练

RNN训练需使用BPTT算法。图17展示了RNN的两种表示:循环形式(左)和按时间展开形式(右)。隐藏状态计算为:

$$ h[t] = g(W_h\begin{bmatrix}x[t]\h[t-1]\end{bmatrix} + b_h) $$

BPTT示意图 图17:随时间反向传播

语言模型中的批处理

处理符号序列时,可将文本批量化为不同尺寸(图18)。设置BPTT周期T后,需在最后一步切断梯度传播(使用.detach()),防止无限传播(图19)。

批处理示意图 图18:批处理示例

梯度消失与爆炸问题

问题分析

RNN中的矩阵变换可能导致梯度随时间膨胀(爆炸)或收缩(消失)。如图20所示,初始较大的梯度(亮色)经过几次旋转后可能完全消失。

梯度消失问题 图20:梯度消失问题

解决方案:跳跃连接

通过门控机制(如LSTM)实现跳跃连接(图21),将网络分割为多个子网络,选择性控制梯度传播路径。

跳跃连接 图21:跳跃连接解决方案

长短期记忆网络(LSTM)

架构详解

LSTM通过三个门控单元(输入门、遗忘门、输出门)精细控制信息流(图22):

  • 输入门(黄色):调节新信息的写入
  • 遗忘门:控制历史记忆的保留
  • 输出门:控制信息的输出

数学表示为: $$ \begin{aligned} i[t] &= \sigma(W_i[h[t-1],x[t]]+b_i) \ f[t] &= \sigma(W_f[h[t-1],x[t]]+b_f) \ o[t] &= \sigma(W_o[h[t-1],x[t]]+b_o) \ \tilde{c}[t] &= \tanh(W_c[h[t-1],x[t]]+b_c) \ c[t] &= f[t]\odot c[t-1] + i[t]\odot \tilde{c}[t] \ h[t] &= o[t]\odot \tanh(c[t]) \end{aligned} $$

LSTM架构 图22:LSTM核心架构

门控机制可视化

通过设置不同的门控状态(0/1),LSTM可以实现:

  • 重置记忆(图26):遗忘门=0,输入门=0
  • 保持记忆(图27):遗忘门=1,输入门=0
  • 写入记忆(图28):遗忘门=0,输入门=1

LSTM记忆单元 图25:记忆单元可视化

实战应用:序列分类

在PyTorch中实现序列分类任务时,通常将输入元素和目标表示为局部向量。通过LSTM处理变长序列,最后使用全连接层进行分类。关键步骤包括:

  1. 嵌入层:将离散符号转换为连续向量
  2. LSTM层:处理变长序列
  3. 全连接层:输出分类结果
  4. 使用交叉熵损失进行训练

这种架构可广泛应用于文本分类、情感分析等任务,展现了RNN/LSTM处理序列数据的强大能力。

NYU-DLSP20 NYU Deep Learning Spring 2020 NYU-DLSP20 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-Deep-Learning

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

方拓行Sandra

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值