Recurrent neural networks (RNNs) are the core models for processing sequential data and play a key role in natural language processing, time-series analysis, and related fields. This article walks through the internal mechanisms, mathematical formulation, and practical use of the vanilla RNN, the LSTM, and the GRU, comparing all three on the same worked example to show how sequence models have evolved.
1. RNN Basics: The Core Idea of Sequence Modeling
1.1 The Essence and Core Mechanism of RNNs
The core innovation of the RNN (Recurrent Neural Network) is its "recurrent memory" mechanism: the hidden state from the previous time step is combined with the current input, which lets the model capture dependencies along the sequence. In essence, it extracts features along the time dimension with shared parameters. Mathematically:
$h_t = f(W \cdot [x_t, h_{t-1}] + b)$
- $x_t$: input vector at the current time step
- $h_{t-1}$: hidden state from the previous time step
- $f$: activation function (usually tanh or sigmoid)
This structure lets an RNN capture short-range dependencies in a sequence. For example, deciding the part of speech of "apples" in "I love eating apples" depends on the preceding verb "eating". A minimal sketch of one recurrent step is shown below.
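The following is a minimal illustrative sketch of the recurrence above (the helper name `rnn_step` and the random toy parameters are ours, not a library API). It uses the equivalent split form $W \cdot [x_t, h_{t-1}] = W_{ih} x_t + W_{hh} h_{t-1}$.

```python
import torch

def rnn_step(x_t, h_prev, W_ih, W_hh, b):
    # One recurrent step: h_t = tanh(W_ih @ x_t + W_hh @ h_prev + b),
    # equivalent to tanh(W @ [x_t, h_prev] + b) with W = [W_ih | W_hh].
    return torch.tanh(W_ih @ x_t + W_hh @ h_prev + b)

torch.manual_seed(0)
W_ih, W_hh, b = torch.randn(4, 3), torch.randn(4, 4), torch.zeros(4)  # 3-dim input, 4-dim hidden

h = torch.zeros(4)                    # h_0
for x_t in torch.randn(5, 3):         # a toy input sequence of length 5
    h = rnn_step(x_t, h, W_ih, W_hh, b)
print(h)                              # the final hidden state summarizes the whole sequence
```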
1.2 Application Scenarios and Structural Variants
By input/output structure, RNNs fall into four categories:
- N vs N: input and output of equal length, e.g. frame-level labeling in speech recognition
- N vs 1: sequence in, single output out, e.g. text classification
- 1 vs N: single input, sequence output, e.g. generating a caption from an image
- N vs M: the seq2seq architecture with unconstrained input/output lengths, the foundation of machine translation
2. Vanilla RNN: The Starting Point of Sequence Models
2.1 Internal Structure and Mathematical Formulation
The computation of a vanilla RNN breaks down into:
- Input concatenation: $[x_t, h_{t-1}]$
- Linear transform: $W \cdot [x_t, h_{t-1}] + b$
- Activation: $h_t = \tanh(\cdot)$
For an RNN with a 3-dimensional input and a 4-dimensional hidden state, the parameters are:
- Input weights $W_{ih} \in \mathbb{R}^{4 \times 3}$
- Hidden weights $W_{hh} \in \mathbb{R}^{4 \times 4}$
- Bias $b \in \mathbb{R}^4$
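These shapes can be checked directly on `torch.nn.RNN`. Note that PyTorch stores two bias vectors (`bias_ih_l0` and `bias_hh_l0`) where the formulation above uses a single $b$; this is only a bookkeeping difference. A small inspection sketch:

```python
import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=4)          # single layer, 3-dim input, 4-dim hidden
print(rnn.weight_ih_l0.shape)                      # torch.Size([4, 3]) -> W_ih
print(rnn.weight_hh_l0.shape)                      # torch.Size([4, 4]) -> W_hh
print(rnn.bias_ih_l0.shape, rnn.bias_hh_l0.shape)  # two 4-dim bias vectors
```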
2.2 A Worked Example
Feature extraction for the name "Bob"
Assumptions:
- Character encodings: 'B' = [0.1, 0, 0], 'o' = [0, 0.1, 0], 'b' = [0, 0, 0.1]
- RNN parameters (simplified for illustration):
  - $W_{ih} = [[1,0,0],[0,1,0],[0,0,1],[1,1,1]]$
  - $W_{hh} = I_4$ (the $4 \times 4$ identity matrix)
Computation steps (with $h_0 = 0$ and the bias omitted):
- Time step 1, input 'B':
  $h_1 = \tanh(W_{ih} \cdot x_B + W_{hh} \cdot h_0) = \tanh([0.1, 0, 0, 0.1]) \approx [0.100, 0, 0, 0.100]$
- Time step 2, input 'o':
  $h_2 = \tanh(W_{ih} \cdot x_o + W_{hh} \cdot h_1) = \tanh([0.100, 0.100, 0, 0.200]) \approx [0.099, 0.100, 0, 0.197]$
- Time step 3, input 'b':
  $h_3 = \tanh(W_{ih} \cdot x_b + W_{hh} \cdot h_2) = \tanh([0.099, 0.100, 0.100, 0.297]) \approx [0.099, 0.099, 0.100, 0.289]$
The final hidden state $h_3$ has absorbed information from every character of "Bob" and serves as its sequence feature representation. The short script below reproduces this computation.
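A short script that reproduces the hand computation above, using the toy weights assumed in this example (illustrative values, not trained parameters):

```python
import torch

W_ih = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.], [1., 1., 1.]])
W_hh = torch.eye(4)                          # identity hidden-to-hidden weights
chars = {'B': torch.tensor([0.1, 0., 0.]),
         'o': torch.tensor([0., 0.1, 0.]),
         'b': torch.tensor([0., 0., 0.1])}

h = torch.zeros(4)                           # h_0 is all zeros
for c in "Bob":
    h = torch.tanh(W_ih @ chars[c] + W_hh @ h)
    print(c, [round(v, 3) for v in h.tolist()])
# the last line prints roughly: b [0.099, 0.099, 0.1, 0.289]
```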
2.3 The RNN API in PyTorch
- Class definition and core parameters
```python
torch.nn.RNN(
    input_size,            # input feature dimension
    hidden_size,           # hidden state dimension
    num_layers=1,          # number of stacked RNN layers
    nonlinearity='tanh',   # activation, 'tanh' or 'relu'
    bias=True,             # whether to use biases
    batch_first=False,     # whether inputs are (batch, seq, feature)
    dropout=0,             # dropout probability between layers
    bidirectional=False    # whether the RNN is bidirectional
)
```
- Input and output format
  Inputs:
  - input: the input sequence, shape (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
  - h_0: the initial hidden state, shape (num_layers * num_directions, batch, hidden_size)
  Outputs:
  - output: the hidden states of all time steps, shape (seq_len, batch, hidden_size * num_directions)
  - h_n: the hidden state of the last time step, same shape as h_0
- Key attributes and methods
  Weight matrices:
  - weight_ih_l[k]: input-to-hidden weights of layer k
  - weight_hh_l[k]: hidden-to-hidden weights of layer k
  - bias_ih_l[k] and bias_hh_l[k]: the corresponding biases
  Forward pass:
```python
output, h_n = rnn(input, h_0)
```
2.4 Code Examples
- Basic usage
```python
import torch
import torch.nn as nn

# Build the RNN model
rnn = nn.RNN(
    input_size=10,        # input feature dimension
    hidden_size=20,       # hidden state dimension
    num_layers=2,         # two stacked RNN layers
    batch_first=True,     # use the (batch, seq, feature) layout
    bidirectional=True    # bidirectional RNN
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)   # input sequence

# Initial hidden state (optional)
h0 = torch.zeros(2 * 2, batch_size, 20)    # 2 layers * 2 directions

# Forward pass
output, hn = rnn(x, h0)

# Output shapes
print("Output shape:", output.shape)        # (3, 5, 40)  [batch, seq, hidden * 2]
print("Final hidden shape:", hn.shape)      # (4, 3, 20)  [layers * directions, batch, hidden]
```
- Getting the last time step's hidden state
```python
# Option 1: take it from output
last_output = output[:, -1, :]               # (batch, hidden * directions)

# Option 2: take it from hn (for a bidirectional RNN, concatenate both directions of the last layer)
last_hidden = hn[-2:] if rnn.bidirectional else hn[-1]
last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1) if rnn.bidirectional else last_hidden
```
2.5 Strengths, Weaknesses, and the Gradient Problem
- Strengths: simple structure, few parameters, efficient on short sequences
- Fatal flaw: severe gradient vanishing on long sequences. Backpropagation through time multiplies factors across steps,
  $\nabla = \prod_{i=1}^{n} \sigma'(z_i) \cdot w_i$
  and when these factors are smaller than 1 the product shrinks toward 0, so parameters far back in the sequence barely receive any gradient signal.
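A rough empirical illustration, not a proof: for an untrained single-layer RNN, the gradient that reaches the first input typically shrinks as the sequence grows. The sketch below measures its norm for a few sequence lengths; the exact values depend on the random initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=8, batch_first=True)

for seq_len in (5, 50, 200):
    x = torch.randn(1, seq_len, 8, requires_grad=True)
    output, _ = rnn(x)
    output[:, -1, :].sum().backward()               # backprop from the last time step only
    print(seq_len, x.grad[0, 0].norm().item())      # gradient norm at the first time step
```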
3. LSTM: Gating Mechanisms for Long-Term Dependencies
3.1 The Four Gate Computations in Detail
By introducing a gating system, the LSTM splits the vanilla RNN's single hidden state into:
- Cell state $C$: the carrier of long-term memory
- Hidden state $h$: the short-term feature representation
Core equations:
- Forget gate: decides which historical information to discard
  - Function: selects which parts of the previous cell state to drop.
  - Computation: concatenate the current input $x_t$ with the previous hidden state $h_{t-1}$ and pass the result through a fully connected layer with a sigmoid activation:
    $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$
  - $f_t$ lies between 0 and 1; 1 means "keep everything", 0 means "forget everything".
- Input gate: filters the new information
  - Function: decides which parts of the current input should be written into the cell state.
  - Computation:
    - Gate value (sigmoid, analogous to the forget gate): $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
    - Candidate cell state (tanh): $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$
- Cell state update: stores the long-term memory, updated through the gates
  - $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
  - $f_t \odot C_{t-1}$: the forget gate acts on the old cell state and drops part of the history
  - $i_t \odot \tilde{C}_t$: the input gate filters the new candidate information before it is added
- Output gate: produces the current hidden state
  - Function: decides which parts of the cell state are exposed as the hidden state.
  - Computation:
    - Gate value: $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
    - Hidden state: the cell state goes through tanh and is multiplied by the gate value, $h_t = o_t \odot \tanh(C_t)$
A minimal code sketch of one LSTM step follows.
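Below is a minimal sketch of one LSTM time step written directly from the equations above. The helper name `lstm_step` and the weight layout (one $(hidden, hidden + input)$ matrix per gate) are our illustrative choices; PyTorch's `nn.LSTM` packs all four gates into combined weight tensors instead.

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    # One LSTM step; each W_* has shape (hidden, hidden + input), each b_* has shape (hidden,).
    z = torch.cat([h_prev, x_t])             # [h_{t-1}, x_t]
    f_t = torch.sigmoid(W_f @ z + b_f)       # forget gate
    i_t = torch.sigmoid(W_i @ z + b_i)       # input gate
    c_tilde = torch.tanh(W_c @ z + b_c)      # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde       # cell state update
    o_t = torch.sigmoid(W_o @ z + b_o)       # output gate
    h_t = o_t * torch.tanh(c_t)              # new hidden state
    return h_t, c_t
```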
3.2 The "Bob" Example: Full LSTM Computation
To allow a direct comparison with the RNN, we keep the same input and hidden dimensions and a similar parameter setup.
Assumptions:
- Input dimension = 3, hidden dimension = 4
- Character encodings: 'B' = [0.1, 0, 0], 'o' = [0, 0.1, 0], 'b' = [0, 0, 0.1]
- LSTM parameters (simplified; each weight matrix is 4 × 7 because $h_{t-1}$ (4-dim) is concatenated with $x_t$ (3-dim)):
  - $W_f = W_i = W_C = W_o = [[1,0,0,1,0,0,0],[0,1,0,0,1,0,0],[0,0,1,0,0,1,0],[0,0,0,1,0,0,1]]$
  - All biases are zero; $h_0 = C_0 = 0$
Detailed computation:
Time step 1: input 'B' = [0.1, 0, 0]
- Forget gate:
  $f_1 = \sigma(W_f \cdot [h_0, x_B] + b_f)$; with $b_f = 0$ and $W_f \cdot [h_0, x_B] = [0.1, 0, 0, 0]$,
  $f_1 = \sigma([0.1, 0, 0, 0]) \approx [0.525, 0.5, 0.5, 0.5]$
- Input gate:
  $i_1 = \sigma(W_i \cdot [h_0, x_B] + b_i) \approx [0.525, 0.5, 0.5, 0.5]$
  $\tilde{C}_1 = \tanh(W_C \cdot [h_0, x_B] + b_C) = \tanh([0.1, 0, 0, 0]) \approx [0.099, 0, 0, 0]$
- Cell state update:
  $C_1 = f_1 \odot C_0 + i_1 \odot \tilde{C}_1 = [0.525, 0.5, 0.5, 0.5] \odot [0, 0, 0, 0] + [0.525, 0.5, 0.5, 0.5] \odot [0.099, 0, 0, 0] \approx [0.052, 0, 0, 0]$
- Output gate and hidden state:
  $o_1 = \sigma(W_o \cdot [h_0, x_B] + b_o) \approx [0.525, 0.5, 0.5, 0.5]$
  $h_1 = o_1 \odot \tanh(C_1) \approx [0.525, 0.5, 0.5, 0.5] \odot [0.052, 0, 0, 0] \approx [0.027, 0, 0, 0]$
Time step 2: input 'o' = [0, 0.1, 0]
- Forget gate:
  $W_f \cdot [h_1, x_o] = [0.027, 0.1, 0, 0]$, so $f_2 = \sigma([0.027, 0.1, 0, 0]) \approx [0.507, 0.525, 0.5, 0.5]$
- Input gate:
  $i_2 = \sigma([0.027, 0.1, 0, 0]) \approx [0.507, 0.525, 0.5, 0.5]$
  $\tilde{C}_2 = \tanh([0.027, 0.1, 0, 0]) \approx [0.027, 0.099, 0, 0]$
- Cell state update:
  $C_2 = f_2 \odot C_1 + i_2 \odot \tilde{C}_2 \approx [0.026, 0, 0, 0] + [0.014, 0.052, 0, 0] = [0.04, 0.052, 0, 0]$
- Output gate and hidden state:
  $o_2 = \sigma([0.027, 0.1, 0, 0]) \approx [0.507, 0.525, 0.5, 0.5]$
  $h_2 = o_2 \odot \tanh(C_2) \approx [0.507, 0.525, 0.5, 0.5] \odot [0.04, 0.052, 0, 0] \approx [0.02, 0.027, 0, 0]$
Time step 3: input 'b' = [0, 0, 0.1]
- Forget gate:
  $W_f \cdot [h_2, x_b] = [0.02, 0.027, 0.1, 0]$, so $f_3 = \sigma([0.02, 0.027, 0.1, 0]) \approx [0.505, 0.507, 0.525, 0.5]$
- Input gate:
  $i_3 = \sigma([0.02, 0.027, 0.1, 0]) \approx [0.505, 0.507, 0.525, 0.5]$
  $\tilde{C}_3 = \tanh([0.02, 0.027, 0.1, 0]) \approx [0.02, 0.027, 0.099, 0]$
- Cell state update:
  $C_3 = f_3 \odot C_2 + i_3 \odot \tilde{C}_3 \approx [0.02, 0.026, 0, 0] + [0.01, 0.014, 0.052, 0] = [0.03, 0.04, 0.052, 0]$
- Output gate and hidden state:
  $o_3 = \sigma([0.02, 0.027, 0.1, 0]) \approx [0.505, 0.507, 0.525, 0.5]$
  $h_3 = o_3 \odot \tanh(C_3) \approx [0.015, 0.02, 0.027, 0]$
Comparison of final results

| Model | Feature representation of "Bob" (final hidden state) |
|---|---|
| RNN | [0.099, 0.099, 0.100, 0.289] |
| LSTM | [0.015, 0.02, 0.027, 0] |

Key differences:
- How information is retained:
  - The vanilla RNN keeps accumulating inputs into the hidden state (its last component grows step after step) with no mechanism for deciding what to keep or discard
  - The LSTM balances the contribution of each character through its gates, producing a more even feature representation
- Gradient flow:
  - The RNN's gradient depends on repeated $\tanh$ derivatives (at most 1), so it decays easily
  - The LSTM's cell state passes gradients through $f_t$, which can stay close to 1, avoiding vanishing
3.3 The LSTM API in PyTorch
- Class definition and core parameters
```python
torch.nn.LSTM(
    input_size,            # input feature dimension
    hidden_size,           # hidden state dimension
    num_layers=1,          # number of stacked LSTM layers
    bias=True,             # whether to use biases
    batch_first=False,     # whether inputs are (batch, seq, feature)
    dropout=0,             # dropout probability between layers
    bidirectional=False    # whether the LSTM is bidirectional
)
```
- Input and output format
  Inputs:
  - input: the input sequence, shape (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
  - h_0: the initial hidden state, shape (num_layers * num_directions, batch, hidden_size)
  - c_0: the initial cell state, shape (num_layers * num_directions, batch, hidden_size)
  Outputs:
  - output: the hidden states of all time steps, shape (seq_len, batch, hidden_size * num_directions)
  - h_n: the hidden state of the last time step, same shape as h_0
  - c_n: the cell state of the last time step, same shape as c_0
- Key attributes and methods
  Weight matrices:
  - weight_ih_l[k]: input-to-hidden weights of layer k (the four gates packed together)
  - weight_hh_l[k]: hidden-to-hidden weights of layer k (the four gates packed together)
  - bias_ih_l[k] and bias_hh_l[k]: the corresponding biases
  Forward pass:
```python
output, (h_n, c_n) = lstm(input, (h_0, c_0))
```
3.4 Code Examples
- Basic usage
```python
import torch
import torch.nn as nn

# Build the LSTM model
lstm = nn.LSTM(
    input_size=10,        # input feature dimension
    hidden_size=20,       # hidden state dimension
    num_layers=2,         # two stacked LSTM layers
    batch_first=True,     # use the (batch, seq, feature) layout
    bidirectional=True    # bidirectional LSTM
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)   # input sequence

# Initial hidden and cell states (optional)
h0 = torch.zeros(2 * 2, batch_size, 20)    # 2 layers * 2 directions
c0 = torch.zeros(2 * 2, batch_size, 20)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Output shapes
print("Output shape:", output.shape)        # (3, 5, 40)  [batch, seq, hidden * 2]
print("Final hidden shape:", hn.shape)      # (4, 3, 20)  [layers * directions, batch, hidden]
print("Final cell shape:", cn.shape)        # (4, 3, 20)
```
- Getting the last time step's hidden state
```python
# Option 1: take it from output
last_output = output[:, -1, :]               # (batch, hidden * directions)

# Option 2: take it from hn (for a bidirectional LSTM, concatenate both directions of the last layer)
last_hidden = hn[-2:] if lstm.bidirectional else hn[-1]
last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1) if lstm.bidirectional else last_hidden
```
3.5 The Mathematical Essence of Gating
Because the cell state is updated through a linear combination, $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$, gradients get an almost "direct" path through time instead of the vanilla RNN's multiplicative decay. In simplified form,
$$\frac{\partial C_t}{\partial C_{t-1}} = f_t$$
When $f_t$ is close to 1, the gradient can flow to distant time steps nearly unchanged; this is the core of how the LSTM handles long-term dependencies. The small autograd check below illustrates this identity.
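A tiny autograd check of the simplified identity above, treating the gate values as constants for illustration (in a full LSTM, $f_t$, $i_t$ and $\tilde{C}_t$ also depend on $h_{t-1}$, so this is the simplified view):

```python
import torch

f_t = torch.tensor([0.9, 0.5])                    # assumed gate values (held constant here)
i_t = torch.tensor([0.3, 0.7])
c_tilde = torch.tensor([0.2, -0.4])
c_prev = torch.tensor([1.0, 2.0], requires_grad=True)

c_t = f_t * c_prev + i_t * c_tilde                # the cell state update
c_t.sum().backward()
print(c_prev.grad)                                # tensor([0.9000, 0.5000]) == f_t
```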
4. GRU: A Lightweight Evolution of the LSTM
4.1 A Simplified Two-Gate Structure
The GRU reduces the LSTM's four gate computations to two:
- Update gate: controls how much historical information is kept
  $z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)$
- Reset gate: controls how much historical information is forgotten when forming the candidate state
  $r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)$
Core equations:
- Candidate state: $\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t] + b)$
- State update: $h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
A one-step code sketch follows.
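Below is a minimal sketch of one GRU step, written directly from the equations in this article (the helper name `gru_step` and weight layout are ours; PyTorch's `nn.GRU` uses a slightly different formulation in which the reset gate is applied to $W_{hh} h_{t-1}$ rather than to $h_{t-1}$ before the matrix multiply).

```python
import torch

def gru_step(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    # One GRU step; each W_* has shape (hidden, hidden + input), each b_* has shape (hidden,).
    z_in = torch.cat([h_prev, x_t])
    z_t = torch.sigmoid(W_z @ z_in + b_z)                              # update gate
    r_t = torch.sigmoid(W_r @ z_in + b_r)                              # reset gate
    h_tilde = torch.tanh(W_h @ torch.cat([r_t * h_prev, x_t]) + b_h)   # candidate state
    h_t = (1 - z_t) * h_prev + z_t * h_tilde                           # blend old and new
    return h_t
```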
4.2 The "Bob" Example: Full GRU Computation
To compare directly with the RNN and the LSTM, we again keep the same input and hidden dimensions and a similar parameter setup.
Assumptions:
- Input dimension = 3, hidden dimension = 4
- Character encodings: 'B' = [0.1, 0, 0], 'o' = [0, 0.1, 0], 'b' = [0, 0, 0.1]
- GRU parameters (simplified; each weight matrix is 4 × 7 because $h_{t-1}$ (4-dim) is concatenated with $x_t$ (3-dim)):
  - $W_z = W_r = W_h = [[1,0,0,1,0,0,0],[0,1,0,0,1,0,0],[0,0,1,0,0,1,0],[0,0,0,1,0,0,1]]$
  - All biases are zero; $h_0 = 0$
Detailed computation:
Time step 1: input 'B' = [0.1, 0, 0]
- Update gate:
  $z_1 = \sigma(W_z \cdot [h_0, x_B] + b_z)$; with $b_z = 0$ and $W_z \cdot [h_0, x_B] = [0.1, 0, 0, 0]$,
  $z_1 = \sigma([0.1, 0, 0, 0]) \approx [0.525, 0.5, 0.5, 0.5]$
- Reset gate:
  $r_1 = \sigma(W_r \cdot [h_0, x_B] + b_r) \approx [0.525, 0.5, 0.5, 0.5]$
- Candidate hidden state:
  $\tilde{h}_1 = \tanh(W_h \cdot [r_1 \odot h_0, x_B] + b_h) = \tanh([0.1, 0, 0, 0]) \approx [0.099, 0, 0, 0]$
- Final hidden state:
  $h_1 = (1-z_1) \odot h_0 + z_1 \odot \tilde{h}_1 = [0.475, 0.5, 0.5, 0.5] \odot [0, 0, 0, 0] + [0.525, 0.5, 0.5, 0.5] \odot [0.099, 0, 0, 0] \approx [0.052, 0, 0, 0]$
Time step 2: input 'o' = [0, 0.1, 0]
- Update gate:
  $W_z \cdot [h_1, x_o] = [0.052, 0.1, 0, 0]$, so $z_2 = \sigma([0.052, 0.1, 0, 0]) \approx [0.513, 0.525, 0.5, 0.5]$
- Reset gate:
  $r_2 = \sigma([0.052, 0.1, 0, 0]) \approx [0.513, 0.525, 0.5, 0.5]$
- Candidate hidden state:
  $\tilde{h}_2 = \tanh(W_h \cdot [r_2 \odot h_1, x_o] + b_h) = \tanh([0.027, 0.1, 0, 0]) \approx [0.027, 0.099, 0, 0]$
- Final hidden state:
  $h_2 = (1-z_2) \odot h_1 + z_2 \odot \tilde{h}_2 \approx [0.026, 0, 0, 0] + [0.014, 0.05, 0, 0] = [0.04, 0.05, 0, 0]$
Time step 3: input 'b' = [0, 0, 0.1]
- Update gate:
  $W_z \cdot [h_2, x_b] = [0.04, 0.05, 0.1, 0]$, so $z_3 = \sigma([0.04, 0.05, 0.1, 0]) \approx [0.51, 0.512, 0.525, 0.5]$
- Reset gate:
  $r_3 = \sigma([0.04, 0.05, 0.1, 0]) \approx [0.51, 0.512, 0.525, 0.5]$
- Candidate hidden state:
  $\tilde{h}_3 = \tanh(W_h \cdot [r_3 \odot h_2, x_b] + b_h) = \tanh([0.02, 0.026, 0.1, 0]) \approx [0.02, 0.026, 0.099, 0]$
- Final hidden state:
  $h_3 = (1-z_3) \odot h_2 + z_3 \odot \tilde{h}_3 \approx [0.02, 0.025, 0, 0] + [0.01, 0.013, 0.052, 0] = [0.03, 0.038, 0.052, 0]$
Final feature representations of the three models

| Model | Feature representation of "Bob" (final hidden state) |
|---|---|
| RNN | [0.099, 0.099, 0.100, 0.289] |
| LSTM | [0.015, 0.02, 0.027, 0] |
| GRU | [0.03, 0.038, 0.052, 0] |

Key differences:
- How information is fused:
  - The vanilla RNN simply accumulates inputs into the hidden state, with no control over what is kept
  - The LSTM maintains long-term memory through its cell state, at the cost of more computation
  - The GRU dynamically balances history against the current input through its update gate, with better computational efficiency
- Parameter efficiency:
  - A GRU has roughly 3/4 of the parameters of an LSTM, so it trains faster
  - On short-sequence tasks, a GRU usually reaches accuracy close to an LSTM's
4.3 The GRU API in PyTorch
- Class definition and core parameters
```python
torch.nn.GRU(
    input_size,            # input feature dimension
    hidden_size,           # hidden state dimension
    num_layers=1,          # number of stacked GRU layers
    bias=True,             # whether to use biases
    batch_first=False,     # whether inputs are (batch, seq, feature)
    dropout=0,             # dropout probability between layers
    bidirectional=False    # whether the GRU is bidirectional
)
```
- Input and output format
  Inputs:
  - input: the input sequence, shape (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
  - h_0: the initial hidden state, shape (num_layers * num_directions, batch, hidden_size)
  Outputs:
  - output: the hidden states of all time steps, shape (seq_len, batch, hidden_size * num_directions)
  - h_n: the hidden state of the last time step, same shape as h_0
- Key attributes and methods
  Weight matrices:
  - weight_ih_l[k]: input-to-hidden weights of layer k (all gate weights packed together)
  - weight_hh_l[k]: hidden-to-hidden weights of layer k
  - bias_ih_l[k] and bias_hh_l[k]: the corresponding biases
  Forward pass:
```python
output, h_n = gru(input, h_0)   # unlike the LSTM, there is no cell state c_n
```
4.4 Code Examples
- Basic usage
```python
import torch
import torch.nn as nn

# Build the GRU model
gru = nn.GRU(
    input_size=10,        # input feature dimension
    hidden_size=20,       # hidden state dimension
    num_layers=2,         # two stacked GRU layers
    batch_first=True,     # use the (batch, seq, feature) layout
    bidirectional=True    # bidirectional GRU
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)   # input sequence

# Initial hidden state (optional)
h0 = torch.zeros(2 * 2, batch_size, 20)    # 2 layers * 2 directions

# Forward pass
output, hn = gru(x, h0)

# Output shapes
print("Output shape:", output.shape)        # (3, 5, 40)  [batch, seq, hidden * 2]
print("Final hidden shape:", hn.shape)      # (4, 3, 20)  [layers * directions, batch, hidden]
```
- Getting the last time step's hidden state
```python
# Option 1: take it from output
last_output = output[:, -1, :]               # (batch, hidden * directions)

# Option 2: take it from hn (for a bidirectional GRU, concatenate both directions of the last layer)
last_hidden = hn[-2:] if gru.bidirectional else hn[-1]
last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1) if gru.bidirectional else last_hidden
```
5. Comparing the Three Models and Choosing in Practice
5.1 Core Metrics
| Model | Gates | Parameters (input dim n → hidden dim m) | Long-term dependency | Compute efficiency |
|---|---|---|---|---|
| Vanilla RNN | 0 | $nm + m^2$ | Poor | High |
| LSTM | 4 | $4(nm + m^2)$ | Excellent | Low |
| GRU | 2 | $3(nm + m^2)$ | Good | Medium |
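The 1 : 4 : 3 ratio can be checked by counting the parameters of the three PyTorch modules; with biases included, each count is $(nm + m^2 + 2m)$ multiplied by 1, 4, or 3. The dimensions below are arbitrary example values:

```python
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

n, m = 10, 20   # example input and hidden dimensions
for name, cls in [("RNN", nn.RNN), ("LSTM", nn.LSTM), ("GRU", nn.GRU)]:
    print(name, n_params(cls(input_size=n, hidden_size=m)))
# Prints 640, 2560, 1920 -> ratio 1 : 4 : 3
```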
5.2 Recommended Use Cases
- Vanilla RNN: short-sequence tasks (e.g. text classification with sequences shorter than about 20 tokens) and scenarios with tight compute budgets
- LSTM: long-sequence modeling (machine translation, speech recognition) and accuracy-critical tasks
- GRU: medium-length sequences (dialogue systems, time-series forecasting) and scenarios that need a balance of accuracy and efficiency
5.3 Key Practical Points
- Gradient handling:
  - LSTM and GRU naturally mitigate vanishing gradients, but gradient clipping is still needed to guard against exploding gradients (see the sketch below)
- Parameter initialization:
  - Vanilla RNNs need careful weight initialization to avoid gradient problems
- Bidirectional and multi-layer variants:
  - Bidirectional structures capture dependencies in both directions and stacked layers improve feature extraction, but both significantly increase the computational cost
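A typical training-step pattern with gradient clipping via `torch.nn.utils.clip_grad_norm_`; the model, the loss, and the `max_norm` value below are illustrative placeholders, not recommendations:

```python
import torch
import torch.nn as nn

model = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(3, 5, 10)
output, _ = model(x)
loss = output.pow(2).mean()          # placeholder loss for the sketch

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # clip before the update
optimizer.step()
```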
The evolution of recurrent neural networks is a balancing act between expressive power and computational cost. From the vanilla RNN to the LSTM and then the GRU, every improvement revolves around the same question: how to model sequence dependencies more effectively. In practice, choose the architecture that best matches your sequence lengths, compute budget, and accuracy requirements.