torch.nn.LSTM
(长短期记忆网络)
LSTM
(Long Short-Term Memory)是 循环神经网络(RNN)的一种变种,专门设计来解决传统RNN在处理长序列时出现的 梯度消失 或 梯度爆炸 问题。LSTM通过引入 门控机制 来控制信息的流动,从而保持长时依赖关系。
在PyTorch中,torch.nn.LSTM
是一个用于构建和训练LSTM网络的模块。它是 torch.nn
中的一个重要层(Layer),支持多层堆叠、双向LSTM等特性。
1. 语法
torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True,
batch_first=False, dropout=0, bidirectional=False, proj_size=0)
2. 参数说明
参数 | 说明 |
---|---|
input_size | 每个时间步输入的特征维度 |
hidden_size | 隐藏层的特征维度 |
num_layers | 堆叠的 LSTM 层数(默认为 1) |
bias | 是否使用偏置项(默认为 True ) |
batch_first | True 时输入输出形状为 (batch_size, seq_len, features) |
dropout | 多层 LSTM 之间的 dropout 比例(默认为 0 ) |
bidirectional | True 创建双向 LSTM(默认为 False ) |
proj_size | 如果 > 0,则添加投影层,投影到维度 proj_size |
3. 返回值
LSTM
返回一个元组:
output, (h_n, c_n) = lstm(input, (h_0, c_0))
output
:输出张量,形状为(seq_len, batch_size, num_directions * hidden_size)
(batch_first=False
)(batch_size, seq_len, num_directions * hidden_size)
(batch_first=True
)- 如果
bidirectional=True
,num_directions=2
,否则为1
。
h_n
:最后一个时间步的隐藏状态,形状为
( n u m _ l a y e r s × n u m _ d i r e c t i o n s , b a t c h _ s i z e , h i d d e n _ s i z e ) (num\_layers \times num\_directions, batch\_size, hidden\_size) (num_layers×num_directions,batch_size,hidden_size)c_n
:最后一个时间步的细胞状态,形状同h_n
。
4. LSTM的工作原理
LSTM通过 门控机制 控制信息的流动和遗忘,解决了传统RNN的梯度消失问题。LSTM有三个主要的门:
- 遗忘门(Forget Gate):决定保留过去信息的多少。
- 输入门(Input Gate):决定当前输入信息的影响程度。
- 输出门(Output Gate):决定当前时刻的输出。
LSTM的更新公式如下:
f t = σ ( W f x t + U f h t − 1 + b f ) (遗忘门) i t = σ ( W i x t + U i h t − 1 + b i ) (输入门) c ~ t = tanh ( W c x t + U c h t − 1 + b c ) (候选单元状态) c t = f t ∘ c t − 1 + i t ∘ c ~ t (更新单元状态) o t = σ ( W o x t + U o h t − 1 + b o ) (输出门) h t = o t ∘ tanh ( c t ) (输出) \begin{aligned} f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \quad \text{(遗忘门)} \\ i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \quad \text{(输入门)} \\ \tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \quad \text{(候选单元状态)} \\ c_t &= f_t \circ c_{t-1} + i_t \circ \tilde{c}_t \quad \text{(更新单元状态)} \\ o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \quad \text{(输出门)} \\ h_t &= o_t \circ \tanh(c_t) \quad \text{(输出)} \end{aligned} ftitc~tctotht=σ(Wfxt+Ufht−1+bf)(遗忘门)=σ(Wixt+Uiht−1+bi)(输入门)=tanh(Wcxt+Ucht−1+bc)(候选单元状态)=ft∘ct−1+it∘c~t(更新单元状态)=σ(Woxt+Uoht−1+bo)(输出门)=ot∘tanh(ct)(输出)
其中:
- f t f_t ft 是遗忘门。
- i t i_t it 是输入门。
- c ~ t \tilde{c}_t c~t 是候选单元状态。
- c t c_t ct 是单元状态。
- o t o_t ot 是输出门。
- h t h_t ht 是隐藏状态。
这些门通过学习得出,控制了信息的流动,保证了LSTM可以有效地捕捉长时间依赖。
5. 基本使用
5.1 单层 LSTM
import torch
import torch.nn as nn
# 定义 LSTM 参数
input_size = 10 # 每个时间步的输入维度
hidden_size = 20 # 隐藏层维度
seq_len = 5 # 序列长度
batch_size = 3 # 批次大小
# 创建 LSTM
lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
# 生成随机输入数据 (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)
# 初始化隐藏状态和细胞状态
h0 = torch.zeros(1, batch_size, hidden_size) # (num_layers * num_directions, batch_size, hidden_size)
c0 = torch.zeros(1, batch_size, hidden_size)
# 前向传播
output, (hn, cn) = lstm(x, (h0, c0))
print("输出形状:", output.shape) # (batch_size, seq_len, hidden_size)
print("最后的隐藏状态形状:", hn.shape) # (num_layers, batch_size, hidden_size)
print("最后的细胞状态形状:", cn.shape) # (num_layers, batch_size, hidden_size)
(h0, c0)
如果不提供,默认是0
5.2 堆叠多层 LSTM
lstm = nn.LSTM(input_size, hidden_size, num_layers=3, batch_first=True)
# 初始化多层隐藏状态和细胞状态
h0 = torch.zeros(3, batch_size, hidden_size)
c0 = torch.zeros(3, batch_size, hidden_size)
output, (hn, cn) = lstm(x, (h0, c0))
print("多层 LSTM 的隐藏状态形状:", hn.shape) # (num_layers, batch_size, hidden_size)
print("多层 LSTM 的细胞状态形状:", cn.shape)
5.3 双向 LSTM(Bidirectional)
lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)
# 初始化双向 LSTM
h0 = torch.zeros(2, batch_size, hidden_size) # 2 = num_directions
c0 = torch.zeros(2, batch_size, hidden_size)
output, (hn, cn) = lstm(x, (h0, c0))
print("双向 LSTM 的输出形状:", output.shape) # (batch_size, seq_len, 2 * hidden_size)
print("双向 LSTM 的隐藏状态形状:", hn.shape) # (num_layers * 2, batch_size, hidden_size)
5.4 带投影的 LSTM(proj_size
> 0)
lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, proj_size=8)
# 初始化隐藏和细胞状态
h0 = torch.zeros(1, batch_size, 8) # proj_size 维度
c0 = torch.zeros(1, batch_size, hidden_size)
output, (hn, cn) = lstm(x, (h0, c0))
print("带投影的 LSTM 的输出形状:", output.shape) # (batch_size, seq_len, proj_size)
print("带投影的隐藏状态形状:", hn.shape) # (num_layers, batch_size, proj_size)
6. 输入数据的格式要求
6.1 默认格式:batch_first=False
- 输入形状:
(seq_len, batch_size, input_size)
- 输出形状:
(seq_len, batch_size, hidden_size)
6.2 batch_first=True
- 输入形状:
(batch_size, seq_len, input_size)
- 输出形状:
(batch_size, seq_len, hidden_size)
7. LSTM 的返回值解析
output, (hn, cn) = lstm(x, (h0, c0))
output
:所有时间步的隐藏状态- 形状:
(batch_size, seq_len, num_directions * hidden_size)
(batch_first=True
) - 包含 LSTM 在 每个时间步的输出。
- 形状:
hn
:最后一个时间步的隐藏状态- 形状:
(num_layers * num_directions, batch_size, hidden_size)
- 用于继续下一个批次的序列。
- 形状:
cn
:最后一个时间步的细胞状态- 形状与
hn
相同。
- 形状与
8. 处理变长序列
使用 pack_padded_sequence()
和 pad_packed_sequence()
处理变长序列:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# 填充后的序列
x = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
lengths = torch.tensor([3, 2]) # 实际长度
# 将填充序列打包
x_packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
# 经过 LSTM
output_packed, (hn, cn) = lstm(x_packed)
# 解包还原
output, _ = pad_packed_sequence(output_packed, batch_first=True)
print(output.shape)
9. 常见应用场景
- 自然语言处理(NLP)
- 机器翻译
- 语音识别
- 文本分类
- 时间序列预测
- 股票价格预测
- 温度/传感器数据建模
- 信号处理
- 语音信号建模
- ECG 信号分类
10. 总结
特性 | 说明 |
---|---|
input_size | 每个时间步的输入维度 |
hidden_size | 隐藏层维度 |
num_layers | 多层 LSTM |
batch_first | 输入形状控制 (batch_size, seq_len, features) |
bidirectional | 支持双向 LSTM |
proj_size | 支持 LSTM 投影,提高效率 |
torch.nn.LSTM
是 序列建模的核心组件之一,理解其输入输出格式和多层、双向、投影等参数对实现高效的 RNN 模型至关重要。