The RNN (Recurrent Neural Network) Family Explained: RNN, LSTM, and GRU

As the core model family for sequential data, recurrent neural networks (RNNs) play a key role in natural language processing, time-series analysis, and related fields. This article walks through the internal mechanics, mathematical formulation, and practical use of the vanilla RNN, the LSTM, and the GRU, comparing all three on a single shared example to show how sequence models have evolved.

1. RNN Basics: The Core Idea of Sequence Modeling


1.1 The Essence and Core Mechanism of RNN

The core innovation of the RNN (Recurrent Neural Network) is its "recurrent memory" mechanism: the hidden state from the previous time step is combined with the current input, giving the model the ability to capture dependencies within a sequence. In essence, it extracts features along the time dimension using shared parameters. Mathematically:

$h_t = f(W \cdot [x_t, h_{t-1}] + b)$

  • $x_t$: input vector at the current time step
  • $h_{t-1}$: hidden state from the previous time step
  • $f$: activation function (typically tanh or sigmoid)

This structure lets an RNN capture short-range dependencies in a sequence. For example, determining the part of speech of "苹果" (apple) in "我爱吃苹果" ("I love eating apples") depends on the preceding verb "吃" (eat).
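
The recurrence above is easy to write out by hand. Below is a minimal sketch of a single recurrent step with made-up sizes (input dimension 3, hidden dimension 4) and random weights; the function name rnn_step and all values are illustrative assumptions, not code from the text.

import torch

# Toy recurrence: h_t = tanh(W_ih @ x_t + W_hh @ h_prev + b)
torch.manual_seed(0)
W_ih, W_hh, b = torch.randn(4, 3), torch.randn(4, 4), torch.zeros(4)

def rnn_step(x_t, h_prev):
    return torch.tanh(W_ih @ x_t + W_hh @ h_prev + b)

h = torch.zeros(4)                    # initial hidden state h_0
for x_t in torch.randn(5, 3):         # a toy sequence of 5 time steps
    h = rnn_step(x_t, h)              # the same weights are shared across all steps
print(h)                              # final hidden state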

1.2 Application Scenarios and Structural Categories

By input/output structure, RNNs fall into four categories (a minimal sequence-to-one sketch follows the list):

  • N vs N: input and output have equal length, e.g. frame-level labeling in speech recognition
  • N vs 1: a sequence in, a single output out, the typical setup for text classification
  • 1 vs N: a single input, a sequence out, commonly used for image captioning
  • N vs M: the seq2seq architecture with unconstrained input/output lengths, the basis of machine translation
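
To make the "N vs 1" case concrete, here is a minimal sketch of a sequence classifier that feeds the last hidden state into a linear layer. The class name TextClassifier and all sizes are illustrative assumptions only.

import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    # Toy N-vs-1 model: a sequence goes in, a single class prediction comes out.
    def __init__(self, input_size=10, hidden_size=16, num_classes=3):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):                 # x: (batch, seq_len, input_size)
        _, h_n = self.rnn(x)              # h_n: (num_layers, batch, hidden_size)
        return self.fc(h_n[-1])           # logits: (batch, num_classes)

logits = TextClassifier()(torch.randn(4, 7, 10))
print(logits.shape)                       # torch.Size([4, 3])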

2. Vanilla RNN: The Starting Point of Sequence Models


2.1 Internal Structure and Mathematical Formulation

The computation of a vanilla RNN breaks down into:

  1. Concatenate the inputs: $[x_t, h_{t-1}]$
  2. Linear transformation: $W \cdot [x_t, h_{t-1}] + b$
  3. Activation: $h_t = \tanh(\cdot)$

Take an RNN with a 3-dimensional input and a 4-dimensional hidden state as an example. Its parameter matrices are (verified in the snippet after this list):

  • Input weights $W_{ih} \in \mathbb{R}^{4 \times 3}$
  • Hidden weights $W_{hh} \in \mathbb{R}^{4 \times 4}$
  • Bias $b \in \mathbb{R}^4$
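
These shapes can be checked directly against torch.nn.RNN. Note that PyTorch keeps two bias vectors (bias_ih and bias_hh) rather than the single $b$ used above; this snippet is only a sanity check of the dimensions.

import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=4)
print(rnn.weight_ih_l0.shape)   # torch.Size([4, 3])  -> W_ih
print(rnn.weight_hh_l0.shape)   # torch.Size([4, 4])  -> W_hh
print(rnn.bias_ih_l0.shape)     # torch.Size([4])
print(rnn.bias_hh_l0.shape)     # torch.Size([4])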

2.2 Worked Example

Feature extraction for the name "Bob"
Assume:

  • Character encodings: 'B'=[0.1,0,0], 'o'=[0,0.1,0], 'b'=[0,0,0.1]
  • RNN parameters: $W_{ih} = [[1,0,0],[0,1,0],[0,0,1],[1,1,1]]$ (a simplified example)
  • $W_{hh} = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]$

Computation steps:

  1. Time step 1: input 'B'
    $h_1 = \tanh(W_{ih} \cdot 'B' + W_{hh} \cdot h_0) = \tanh([0.1, 0, 0, 0.1])$
    $h_1 = [0.099, 0, 0, 0.099]$ (assuming $h_0$ is all zeros)

  2. Time step 2: input 'o'
    $h_2 = \tanh(W_{ih} \cdot 'o' + W_{hh} \cdot h_1)$
    $= \tanh([0, 0.1, 0, 0.1+0.099]) = [0, 0.099, 0, 0.197]$

  3. Time step 3: input 'b'
    $h_3 = \tanh([0, 0, 0.1, 0.1+0.197]) = [0, 0, 0.099, 0.292]$

The final hidden state $h_3$ has absorbed information from all of "Bob" and serves as the sequence feature representation of "Bob".

2.3 The RNN API in PyTorch

  1. RNN class definition and core parameters
torch.nn.RNN(
    input_size,           # input feature dimension
    hidden_size,          # hidden state dimension
    num_layers=1,         # number of stacked RNN layers
    nonlinearity='tanh',  # nonlinear activation, 'tanh' or 'relu'
    bias=True,            # whether to use bias terms
    batch_first=False,    # whether inputs are shaped (batch, seq, feature)
    dropout=0,            # dropout probability
    bidirectional=False   # whether the RNN is bidirectional
)
  2. Input and output format

    Inputs:

    • input: the input sequence, shaped (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
    • h_0: the initial hidden state, shaped (num_layers * num_directions, batch, hidden_size)

    Outputs:

    • output: hidden states of every time step, shaped (seq_len, batch, hidden_size * num_directions)
    • h_n: hidden state of the last time step, same shape as h_0
  3. Key attributes and methods

    Weight matrices:

    • weight_ih_l[k]: input-to-hidden weights of layer k
    • weight_hh_l[k]: hidden-to-hidden weights of layer k
    • bias_ih_l[k] and bias_hh_l[k]: the corresponding biases

    Forward pass:

output, h_n = rnn(input, h_0)

2.4 Code Examples

  1. Basic usage
import torch
import torch.nn as nn

# Create the RNN model
rnn = nn.RNN(
    input_size=10,      # input feature dimension
    hidden_size=20,     # hidden state dimension
    num_layers=2,       # 2 stacked RNN layers
    batch_first=True,   # use the (batch, seq, feature) layout
    bidirectional=True  # bidirectional RNN
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)  # input sequence

# Initialize the hidden state (optional)
h0 = torch.zeros(2*2, batch_size, 20)  # 2 layers * 2 directions

# Forward pass
output, hn = rnn(x, h0)

# Shape analysis
print("Output shape:", output.shape)      # (3, 5, 40)  [batch, seq, hidden*2]
print("Final hidden shape:", hn.shape)    # (4, 3, 20)  [layers*directions, batch, hidden]
  2. Getting the hidden state of the last time step
# Method 1: take it from output
last_output = output[:, -1, :]  # (batch, hidden*directions)

# Method 2: take it from hn
last_hidden = hn[-2:] if rnn.bidirectional else hn[-1]  # for a bidirectional RNN, take both directions of the top layer
last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1) if rnn.bidirectional else last_hidden

2.5 Strengths, Weaknesses, and the Gradient Problem

  • Strengths: simple structure, few parameters, efficient on short sequences
  • Fatal flaw: severe gradient vanishing on long sequences. Backpropagation through time multiplies one factor per step:
    $\nabla = \prod_{i=1}^{n} \sigma'(z_i) \cdot w_i$
    When the factors $\sigma'(z_i) \cdot w_i$ are smaller than 1 in magnitude, the repeated product drives the gradient toward 0, so parameters far back in the sequence are barely updated (illustrated by the snippet below).

3. LSTM: Solving Long-Term Dependencies with Gating


3.1 The Four Gates in Detail

The LSTM introduces a gating system that splits the vanilla RNN's single hidden state into:

  • Cell state C: the carrier of long-term memory
  • Hidden state h: the short-term feature representation

Core equations (a single-step code sketch follows this list)

  1. Forget gate: decides which history to discard

    • Function: decides which historical information to drop from the cell state.
    • Computation
      • The current input $x_t$ and the previous hidden state $h_{t-1}$ are concatenated and passed through a fully connected layer with a sigmoid activation: $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$
      • $f_t$ is a gate value between 0 and 1: 1 means "keep everything", 0 means "forget everything".

  2. Input gate: filters new information

    • Function: decides which parts of the new input are written into the cell state.
    • Computation
      • Compute the input gate value $i_t$ (like the forget gate, via a sigmoid): $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
      • Compute the candidate cell state $\tilde{C}_t$ (via tanh): $\tilde{C}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$
  3. Cell state update

    • Function: stores long-term memory, updated through the gates: $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
    • In this update
      • $f_t \odot C_{t-1}$: the forget gate acts on the old cell state, discarding part of the history;
      • $i_t \odot \tilde{C}_t$: the input gate filters the new information carried by the candidate state.

  4. Output gate: produces the current hidden state

    • Function: decides which parts of the cell state are exposed as the current hidden state.
    • Computation
      • Compute the output gate value $o_t$: $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
      • Pass the cell state through tanh and multiply by the gate value: $h_t = o_t \odot \tanh(C_t)$

3.2 The "Bob" Example: A Full LSTM Calculation

To make the comparison with the RNN easy, we keep the same input and hidden dimensions and use similar parameter settings:

Assumptions:

  • Input dimension = 3, hidden state dimension = 4
  • Character encodings: 'B'=[0.1,0,0], 'o'=[0,0.1,0], 'b'=[0,0,0.1]
  • LSTM parameters (simplified):
    W_f = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
    W_i = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
    W_c = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
    W_o = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]

    (Note: each weight matrix is actually 4x7, since it acts on the concatenation of h_{t-1} (4-dim) and x_t (3-dim).)

Detailed computation:

Time step 1: input 'B' = [0.1, 0, 0]

  1. Forget gate:
    $f_1 = \sigma(W_f \cdot [h_0, 'B'] + b_f)$
    Assuming $b_f = [0,0,0,0]$:
    $W_f \cdot [h_0, 'B'] = [[1,0,0,1,0,0,0], \dots] \cdot [0,0,0,0,0.1,0,0]^T = [0.1, 0, 0, 0]$
    $f_1 = \sigma([0.1, 0, 0, 0]) = [0.525, 0.5, 0.5, 0.5]$

  2. Input gate:
    $i_1 = \sigma(W_i \cdot [h_0, 'B'] + b_i) = [0.525, 0.5, 0.5, 0.5]$
    $\tilde{C}_1 = \tanh(W_c \cdot [h_0, 'B'] + b_c) = \tanh([0.1, 0, 0, 0]) = [0.099, 0, 0, 0]$

  3. Cell state update:
    $C_1 = f_1 \odot C_0 + i_1 \odot \tilde{C}_1$
    $= [0.525, 0.5, 0.5, 0.5] \odot [0,0,0,0] + [0.525, 0.5, 0.5, 0.5] \odot [0.099, 0, 0, 0]$
    $= [0.052, 0, 0, 0]$

  4. Output gate and hidden state:
    $o_1 = \sigma(W_o \cdot [h_0, 'B'] + b_o) = [0.525, 0.5, 0.5, 0.5]$
    $h_1 = o_1 \odot \tanh(C_1) = [0.525, 0.5, 0.5, 0.5] \odot [0.052, 0, 0, 0] = [0.027, 0, 0, 0]$

Time step 2: input 'o' = [0, 0.1, 0]

  1. Forget gate:
    $W_f \cdot [h_1, 'o'] = [[1,0,0,1,0,0,0], \dots] \cdot [0.027,0,0,0,0,0.1,0]^T = [0.027, 0.1, 0, 0]$
    $f_2 = \sigma([0.027, 0.1, 0, 0]) = [0.507, 0.525, 0.5, 0.5]$

  2. Input gate:
    $i_2 = \sigma([0.027, 0.1, 0, 0]) = [0.507, 0.525, 0.5, 0.5]$
    $\tilde{C}_2 = \tanh([0.027, 0.1, 0, 0]) = [0.027, 0.099, 0, 0]$

  3. Cell state update:
    $C_2 = f_2 \odot C_1 + i_2 \odot \tilde{C}_2$
    $= [0.026, 0, 0, 0] + [0.014, 0.052, 0, 0] = [0.04, 0.052, 0, 0]$

  4. Output gate and hidden state:
    $o_2 = \sigma([0.027, 0.1, 0, 0]) = [0.507, 0.525, 0.5, 0.5]$
    $h_2 = o_2 \odot \tanh(C_2) = [0.507, 0.525, 0.5, 0.5] \odot [0.04, 0.052, 0, 0] = [0.02, 0.027, 0, 0]$

Time step 3: input 'b' = [0, 0, 0.1]

  1. Forget gate:
    $W_f \cdot [h_2, 'b'] = [0.02, 0.027, 0.1, 0]$
    $f_3 = \sigma([0.02, 0.027, 0.1, 0]) = [0.505, 0.507, 0.525, 0.5]$

  2. Input gate:
    $i_3 = \sigma([0.02, 0.027, 0.1, 0]) = [0.505, 0.507, 0.525, 0.5]$
    $\tilde{C}_3 = \tanh([0.02, 0.027, 0.1, 0]) = [0.02, 0.027, 0.099, 0]$

  3. Cell state update:
    $C_3 = f_3 \odot C_2 + i_3 \odot \tilde{C}_3$
    $= [0.02, 0.026, 0, 0] + [0.01, 0.014, 0.052, 0] = [0.03, 0.04, 0.052, 0]$

  4. Output gate and hidden state:
    $o_3 = \sigma([0.02, 0.027, 0.1, 0]) = [0.505, 0.507, 0.525, 0.5]$
    $h_3 = o_3 \odot \tanh(C_3) = [0.015, 0.02, 0.027, 0]$

Final result comparison

| Model | Feature representation of "Bob" (final hidden state) |
| --- | --- |
| RNN | [0, 0, 0.099, 0.292] |
| LSTM | [0.015, 0.02, 0.027, 0] |

Key differences:

  1. How information is retained

    • The RNN adds history straight into the hidden state, so later inputs carry outsized weight (the influence of 'b' dominates)
    • The LSTM's gates balance the contribution of each character, yielding a more even feature representation
  2. Gradient propagation

    • The RNN's gradient depends on the derivative of $\tanh$ (at most 1) and therefore decays easily
    • The LSTM passes gradients along the cell state through $f_t$ (close to 1), avoiding vanishing

3.3 The LSTM API in PyTorch

  1. LSTM class definition and core parameters
torch.nn.LSTM(
    input_size,           # input feature dimension
    hidden_size,          # hidden state dimension
    num_layers=1,         # number of stacked LSTM layers
    bias=True,            # whether to use bias terms
    batch_first=False,    # whether inputs are shaped (batch, seq, feature)
    dropout=0,            # dropout probability
    bidirectional=False   # whether the LSTM is bidirectional
)
  2. Input and output format

    Inputs:

    • input: the input sequence, shaped (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
    • h_0: the initial hidden state, shaped (num_layers * num_directions, batch, hidden_size)
    • c_0: the initial cell state, shaped (num_layers * num_directions, batch, hidden_size)

    Outputs:

    • output: hidden states of every time step, shaped (seq_len, batch, hidden_size * num_directions)
    • h_n: hidden state of the last time step, same shape as h_0
    • c_n: cell state of the last time step, same shape as c_0
  3. Key attributes and methods

    Weight matrices:

    • weight_ih_l[k]: input-to-hidden weights of layer k (the four gates concatenated)
    • weight_hh_l[k]: hidden-to-hidden weights of layer k (the four gates concatenated)
    • bias_ih_l[k] and bias_hh_l[k]: the corresponding biases
  4. Forward pass

output, (h_n, c_n) = lstm(input, (h_0, c_0))

3.4 Code Examples

  1. Basic usage
import torch
import torch.nn as nn

# Create the LSTM model
lstm = nn.LSTM(
    input_size=10,      # input feature dimension
    hidden_size=20,     # hidden state dimension
    num_layers=2,       # 2 stacked LSTM layers
    batch_first=True,   # use the (batch, seq, feature) layout
    bidirectional=True  # bidirectional LSTM
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)  # input sequence

# Initialize the hidden and cell states (optional)
h0 = torch.zeros(2*2, batch_size, 20)  # 2 layers * 2 directions
c0 = torch.zeros(2*2, batch_size, 20)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Shape analysis
print("Output shape:", output.shape)      # (3, 5, 40)  [batch, seq, hidden*2]
print("Final hidden shape:", hn.shape)    # (4, 3, 20)  [layers*directions, batch, hidden]
print("Final cell shape:", cn.shape)      # (4, 3, 20)
  2. Getting the hidden state of the last time step
# Method 1: take it from output
last_output = output[:, -1, :]  # (batch, hidden*directions)

# Method 2: take it from hn
last_hidden = hn[-2:] if lstm.bidirectional else hn[-1]  # for a bidirectional LSTM, take both directions of the top layer
last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1) if lstm.bidirectional else last_hidden

3.5 The Mathematical Essence of Gating

Through the linear combination $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$, the LSTM gives gradients a near-direct path and avoids the multiplicative decay of the vanilla RNN. Mathematically:
$\frac{\partial C_t}{\partial C_{t-1}} = f_t$
When $f_t$ is close to 1, gradients reach distant time steps almost unchanged; this is the core of how the LSTM handles long-term dependencies.

4. GRU: A Lightweight Evolution of the LSTM


4.1 A Simplified Two-Gate Structure

The GRU condenses the LSTM's four gates into two:

  1. Update gate: controls how much history is retained
    $z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)$
  2. Reset gate: controls how much history is forgotten
    $r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)$

Core equations (a single-step sketch follows this list):

  • Candidate state: $\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t] + b)$
  • State update: $h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
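
As with the LSTM, the equations translate directly into a single-step sketch. The function name gru_step and the weight layout are assumptions for illustration; nn.GRU stores its parameters differently and uses the opposite role for the update gate in its state-update formula.

import torch

def gru_step(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    z_in = torch.cat([h_prev, x_t], dim=-1)
    z_t = torch.sigmoid(z_in @ W_z.T + b_z)    # update gate
    r_t = torch.sigmoid(z_in @ W_r.T + b_r)    # reset gate
    h_tilde = torch.tanh(                      # candidate state, built from the reset history
        torch.cat([r_t * h_prev, x_t], dim=-1) @ W_h.T + b_h)
    return (1 - z_t) * h_prev + z_t * h_tilde  # interpolate between old state and candidate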

4.2 The "Bob" Example: A Full GRU Calculation

To compare with the RNN and LSTM, we keep the same input and hidden dimensions and similar parameter settings:

Assumptions:

  • Input dimension = 3, hidden state dimension = 4
  • Character encodings: 'B'=[0.1,0,0], 'o'=[0,0.1,0], 'b'=[0,0,0.1]
  • GRU parameters (simplified):
    W_z = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
    W_r = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]
    W_h = [[1,0,0,1,0,0,0], [0,1,0,0,1,0,0], [0,0,1,0,0,1,0], [0,0,0,1,0,0,1]]

    (Note: each weight matrix is actually 4x7, since it acts on the concatenation of h_{t-1} (4-dim) and x_t (3-dim).)

Detailed computation:

Time step 1: input 'B' = [0.1, 0, 0]

  1. Update gate:
    $z_1 = \sigma(W_z \cdot [h_0, 'B'] + b_z)$
    Assuming $b_z = [0,0,0,0]$:
    $W_z \cdot [h_0, 'B'] = [[1,0,0,1,0,0,0], \dots] \cdot [0,0,0,0,0.1,0,0]^T = [0.1, 0, 0, 0]$
    $z_1 = \sigma([0.1, 0, 0, 0]) = [0.525, 0.5, 0.5, 0.5]$

  2. Reset gate:
    $r_1 = \sigma(W_r \cdot [h_0, 'B'] + b_r) = [0.525, 0.5, 0.5, 0.5]$

  3. Candidate hidden state:
    $\tilde{h}_1 = \tanh(W_h \cdot [r_1 \odot h_0, 'B'] + b_h)$
    $= \tanh([0.1, 0, 0, 0]) = [0.099, 0, 0, 0]$

  4. Final hidden state:
    $h_1 = (1-z_1) \odot h_0 + z_1 \odot \tilde{h}_1$
    $= [0.475, 0.5, 0.5, 0.5] \odot [0,0,0,0] + [0.525, 0.5, 0.5, 0.5] \odot [0.099, 0, 0, 0]$
    $= [0.052, 0, 0, 0]$

Time step 2: input 'o' = [0, 0.1, 0]

  1. Update gate:
    $W_z \cdot [h_1, 'o'] = [[1,0,0,1,0,0,0], \dots] \cdot [0.052,0,0,0,0,0.1,0]^T = [0.052, 0.1, 0, 0]$
    $z_2 = \sigma([0.052, 0.1, 0, 0]) = [0.513, 0.525, 0.5, 0.5]$

  2. Reset gate:
    $r_2 = \sigma([0.052, 0.1, 0, 0]) = [0.513, 0.525, 0.5, 0.5]$

  3. Candidate hidden state:
    $\tilde{h}_2 = \tanh(W_h \cdot [r_2 \odot h_1, 'o'] + b_h)$
    $= \tanh([0.027, 0.1, 0, 0]) = [0.027, 0.099, 0, 0]$

  4. Final hidden state:
    $h_2 = (1-z_2) \odot h_1 + z_2 \odot \tilde{h}_2$
    $= [0.026, 0, 0, 0] + [0.014, 0.05, 0, 0] = [0.04, 0.05, 0, 0]$

Time step 3: input 'b' = [0, 0, 0.1]

  1. Update gate:
    $W_z \cdot [h_2, 'b'] = [0.04, 0.05, 0.1, 0]$
    $z_3 = \sigma([0.04, 0.05, 0.1, 0]) = [0.51, 0.512, 0.525, 0.5]$

  2. Reset gate:
    $r_3 = \sigma([0.04, 0.05, 0.1, 0]) = [0.51, 0.512, 0.525, 0.5]$

  3. Candidate hidden state:
    $\tilde{h}_3 = \tanh(W_h \cdot [r_3 \odot h_2, 'b'] + b_h)$
    $= \tanh([0.02, 0.026, 0.1, 0]) = [0.02, 0.026, 0.099, 0]$

  4. Final hidden state:
    $h_3 = (1-z_3) \odot h_2 + z_3 \odot \tilde{h}_3$
    $= [0.02, 0.025, 0, 0] + [0.01, 0.013, 0.052, 0] = [0.03, 0.038, 0.052, 0]$

Final feature representations of the three models

| Model | Feature representation of "Bob" (final hidden state) |
| --- | --- |
| RNN | [0, 0, 0.099, 0.292] |
| LSTM | [0.015, 0.02, 0.027, 0] |
| GRU | [0.03, 0.038, 0.052, 0] |

Key differences:

  1. How information is fused

    • The RNN adds inputs directly, so later information ends up dominating
    • The LSTM keeps long-term memory in a separate cell state, at a higher computational cost
    • The GRU balances history and current input dynamically via the update gate, with better computational efficiency
  2. Parameter efficiency

    • The GRU has roughly 3/4 of the LSTM's parameters and trains faster
    • On short-sequence tasks, the GRU usually comes close to the LSTM's performance

4.3 The GRU API in PyTorch

  1. GRU class definition and core parameters
torch.nn.GRU(
    input_size,           # input feature dimension
    hidden_size,          # hidden state dimension
    num_layers=1,         # number of stacked GRU layers
    bias=True,            # whether to use bias terms
    batch_first=False,    # whether inputs are shaped (batch, seq, feature)
    dropout=0,            # dropout probability
    bidirectional=False   # whether the GRU is bidirectional
)
  2. Input and output format
    Inputs:

    • input: the input sequence, shaped (seq_len, batch, input_size) by default, or (batch, seq_len, input_size) when batch_first=True
    • h_0: the initial hidden state, shaped (num_layers * num_directions, batch, hidden_size)

    Outputs:

    • output: hidden states of every time step, shaped (seq_len, batch, hidden_size * num_directions)
    • h_n: hidden state of the last time step, same shape as h_0
  3. Key attributes and methods
    Weight matrices:

    • weight_ih_l[k]: input-to-hidden weights of layer k (the reset, update, and candidate transforms concatenated)
    • weight_hh_l[k]: hidden-to-hidden weights of layer k
    • bias_ih_l[k] and bias_hh_l[k]: the corresponding biases

    Forward pass:

output, h_n = gru(input, h_0)  # unlike the LSTM, there is no cell state c_n

4.4 Code Examples

  1. Basic usage
import torch
import torch.nn as nn

# Create the GRU model
gru = nn.GRU(
    input_size=10,      # input feature dimension
    hidden_size=20,     # hidden state dimension
    num_layers=2,       # 2 stacked GRU layers
    batch_first=True,   # use the (batch, seq, feature) layout
    bidirectional=True  # bidirectional GRU
)

# Prepare the input
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, 10)  # input sequence

# Initialize the hidden state (optional)
h0 = torch.zeros(2*2, batch_size, 20)  # 2 layers * 2 directions

# Forward pass
output, hn = gru(x, h0)

# Shape analysis
print("Output shape:", output.shape)      # (3, 5, 40)  [batch, seq, hidden*2]
print("Final hidden shape:", hn.shape)    # (4, 3, 20)  [layers*directions, batch, hidden]
  2. Getting the hidden state of the last time step
# Method 1: take it from output
last_output = output[:, -1, :]  # (batch, hidden*directions)

# Method 2: take it from hn
last_hidden = hn[-2:] if gru.bidirectional else hn[-1]  # for a bidirectional GRU, take both directions of the top layer
last_hidden = torch.cat([last_hidden[0], last_hidden[1]], dim=1) if gru.bidirectional else last_hidden

5. Comparing the Three Models and Choosing in Practice

5.1 Core Metrics

| Model | Gates | Parameters (input n → hidden m) | Long-term dependency | Computational efficiency |
| --- | --- | --- | --- | --- |
| Vanilla RNN | 0 | nm + mm | weak | high |
| LSTM | 4 | 4 × (nm + mm) | strong | lower |
| GRU | 2 | 3 × (nm + mm) | close to LSTM | medium-high |
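
The parameter column can be sanity-checked against PyTorch's built-in layers. The counts below also include the bias terms (2m per gate), so they come out to roughly 1x, 4x, and 3x of (nm + mm + 2m); the sizes are arbitrary.

import torch.nn as nn

def num_params(module):
    return sum(p.numel() for p in module.parameters())

n, m = 10, 20   # input size n, hidden size m
for name, cls in [("RNN", nn.RNN), ("LSTM", nn.LSTM), ("GRU", nn.GRU)]:
    layer = cls(input_size=n, hidden_size=m)
    print(f"{name}: {num_params(layer)} parameters")
# Expected for n=10, m=20: RNN 640, LSTM 2560, GRU 1920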

5.2 Recommended Use Cases

  • Vanilla RNN
    Short-sequence tasks (e.g. text classification with length < 20) and scenarios with strictly limited compute

  • LSTM
    Long-sequence modeling (machine translation, speech recognition) and tasks with high accuracy requirements

  • GRU
    Medium-length sequences (e.g. dialogue systems, time-series forecasting) and scenarios that need a balance of accuracy and efficiency

5.3 Key Points

  1. Gradient handling
    • LSTM/GRU naturally mitigate gradient vanishing, but gradient clipping is still needed to guard against explosion (see the sketch after this list)
  2. Parameter initialization
    • Vanilla RNNs need carefully initialized weights to avoid gradient problems
  3. Bidirectional and multi-layer variants
    • Bidirectional structures capture dependencies in both directions and stacked layers improve feature extraction, but both add significant computation
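
A minimal sketch of where gradient clipping fits into a training step, assuming a placeholder model, optimizer, and loss (all values here are illustrative):

import torch
import torch.nn as nn

model = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(3, 5, 10)
output, _ = model(x)
loss = output.pow(2).mean()     # placeholder loss, just for illustration

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before the optimizer step
optimizer.step()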

The evolution of recurrent neural networks is a balancing act between expressive power and computational efficiency. From RNN to LSTM to GRU, every improvement has revolved around modeling sequence dependencies more effectively. In practice, choose the architecture that best fits your sequence lengths, compute budget, and accuracy requirements.
