【PyTorch】torch.nn.LSTM 类:长短期记忆网络

torch.nn.LSTM(长短期记忆网络)

LSTMLong Short-Term Memory)是 循环神经网络(RNN)的一种变种,专门设计来解决传统RNN在处理长序列时出现的 梯度消失梯度爆炸 问题。LSTM通过引入 门控机制 来控制信息的流动,从而保持长时依赖关系。

在PyTorch中,torch.nn.LSTM 是一个用于构建和训练LSTM网络的模块。它是 torch.nn 中的一个重要层(Layer),支持多层堆叠、双向LSTM等特性。


1. 语法

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True,
              batch_first=False, dropout=0, bidirectional=False, proj_size=0)

2. 参数说明

参数说明
input_size每个时间步输入的特征维度
hidden_size隐藏层的特征维度
num_layers堆叠的 LSTM 层数(默认为 1)
bias是否使用偏置项(默认为 True
batch_firstTrue 时输入输出形状为 (batch_size, seq_len, features)
dropout多层 LSTM 之间的 dropout 比例(默认为 0
bidirectionalTrue 创建双向 LSTM(默认为 False
proj_size如果 > 0,则添加投影层,投影到维度 proj_size

3. 返回值

LSTM 返回一个元组:

output, (h_n, c_n) = lstm(input, (h_0, c_0))
  • output:输出张量,形状为
    • (seq_len, batch_size, num_directions * hidden_size)batch_first=False
    • (batch_size, seq_len, num_directions * hidden_size)batch_first=True
    • 如果 bidirectional=Truenum_directions=2,否则为 1
  • h_n:最后一个时间步的隐藏状态,形状为
    ( n u m _ l a y e r s × n u m _ d i r e c t i o n s , b a t c h _ s i z e , h i d d e n _ s i z e ) (num\_layers \times num\_directions, batch\_size, hidden\_size) (num_layers×num_directions,batch_size,hidden_size)
  • c_n:最后一个时间步的细胞状态,形状同 h_n

4. LSTM的工作原理

LSTM通过 门控机制 控制信息的流动和遗忘,解决了传统RNN的梯度消失问题。LSTM有三个主要的门:

  • 遗忘门(Forget Gate):决定保留过去信息的多少。
  • 输入门(Input Gate):决定当前输入信息的影响程度。
  • 输出门(Output Gate):决定当前时刻的输出。

LSTM的更新公式如下:

f t = σ ( W f x t + U f h t − 1 + b f ) (遗忘门) i t = σ ( W i x t + U i h t − 1 + b i ) (输入门) c ~ t = tanh ⁡ ( W c x t + U c h t − 1 + b c ) (候选单元状态) c t = f t ∘ c t − 1 + i t ∘ c ~ t (更新单元状态) o t = σ ( W o x t + U o h t − 1 + b o ) (输出门) h t = o t ∘ tanh ⁡ ( c t ) (输出) \begin{aligned} f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \quad \text{(遗忘门)} \\ i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \quad \text{(输入门)} \\ \tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \quad \text{(候选单元状态)} \\ c_t &= f_t \circ c_{t-1} + i_t \circ \tilde{c}_t \quad \text{(更新单元状态)} \\ o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \quad \text{(输出门)} \\ h_t &= o_t \circ \tanh(c_t) \quad \text{(输出)} \end{aligned} ftitc~tctotht=σ(Wfxt+Ufht1+bf)(遗忘门)=σ(Wixt+Uiht1+bi)(输入门)=tanh(Wcxt+Ucht1+bc)(候选单元状态)=ftct1+itc~t(更新单元状态)=σ(Woxt+Uoht1+bo)(输出门)=ottanh(ct)(输出)

其中:

  • f t f_t ft 是遗忘门。
  • i t i_t it 是输入门。
  • c ~ t \tilde{c}_t c~t 是候选单元状态。
  • c t c_t ct 是单元状态。
  • o t o_t ot 是输出门。
  • h t h_t ht 是隐藏状态。

这些门通过学习得出,控制了信息的流动,保证了LSTM可以有效地捕捉长时间依赖。


5. 基本使用

5.1 单层 LSTM

import torch
import torch.nn as nn

# 定义 LSTM 参数
input_size = 10    # 每个时间步的输入维度
hidden_size = 20   # 隐藏层维度
seq_len = 5        # 序列长度
batch_size = 3     # 批次大小

# 创建 LSTM
lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)

# 生成随机输入数据 (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)

# 初始化隐藏状态和细胞状态
h0 = torch.zeros(1, batch_size, hidden_size)  # (num_layers * num_directions, batch_size, hidden_size)
c0 = torch.zeros(1, batch_size, hidden_size)

# 前向传播
output, (hn, cn) = lstm(x, (h0, c0))

print("输出形状:", output.shape)  # (batch_size, seq_len, hidden_size)
print("最后的隐藏状态形状:", hn.shape)  # (num_layers, batch_size, hidden_size)
print("最后的细胞状态形状:", cn.shape)  # (num_layers, batch_size, hidden_size)

(h0, c0)如果不提供,默认是0


5.2 堆叠多层 LSTM

lstm = nn.LSTM(input_size, hidden_size, num_layers=3, batch_first=True)

# 初始化多层隐藏状态和细胞状态
h0 = torch.zeros(3, batch_size, hidden_size)
c0 = torch.zeros(3, batch_size, hidden_size)

output, (hn, cn) = lstm(x, (h0, c0))

print("多层 LSTM 的隐藏状态形状:", hn.shape)  # (num_layers, batch_size, hidden_size)
print("多层 LSTM 的细胞状态形状:", cn.shape)

5.3 双向 LSTM(Bidirectional)

lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)

# 初始化双向 LSTM
h0 = torch.zeros(2, batch_size, hidden_size)  # 2 = num_directions
c0 = torch.zeros(2, batch_size, hidden_size)

output, (hn, cn) = lstm(x, (h0, c0))

print("双向 LSTM 的输出形状:", output.shape)  # (batch_size, seq_len, 2 * hidden_size)
print("双向 LSTM 的隐藏状态形状:", hn.shape)  # (num_layers * 2, batch_size, hidden_size)

5.4 带投影的 LSTM(proj_size > 0)

lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, proj_size=8)

# 初始化隐藏和细胞状态
h0 = torch.zeros(1, batch_size, 8)  # proj_size 维度
c0 = torch.zeros(1, batch_size, hidden_size)

output, (hn, cn) = lstm(x, (h0, c0))

print("带投影的 LSTM 的输出形状:", output.shape)  # (batch_size, seq_len, proj_size)
print("带投影的隐藏状态形状:", hn.shape)  # (num_layers, batch_size, proj_size)

6. 输入数据的格式要求

6.1 默认格式:batch_first=False

  • 输入形状:(seq_len, batch_size, input_size)
  • 输出形状:(seq_len, batch_size, hidden_size)

6.2 batch_first=True

  • 输入形状:(batch_size, seq_len, input_size)
  • 输出形状:(batch_size, seq_len, hidden_size)

7. LSTM 的返回值解析

output, (hn, cn) = lstm(x, (h0, c0))
  • output:所有时间步的隐藏状态
    • 形状:(batch_size, seq_len, num_directions * hidden_size)batch_first=True
    • 包含 LSTM 在 每个时间步的输出
  • hn:最后一个时间步的隐藏状态
    • 形状:(num_layers * num_directions, batch_size, hidden_size)
    • 用于继续下一个批次的序列。
  • cn:最后一个时间步的细胞状态
    • 形状与 hn 相同。

8. 处理变长序列

使用 pack_padded_sequence()pad_packed_sequence() 处理变长序列:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# 填充后的序列
x = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
lengths = torch.tensor([3, 2])  # 实际长度

# 将填充序列打包
x_packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

# 经过 LSTM
output_packed, (hn, cn) = lstm(x_packed)

# 解包还原
output, _ = pad_packed_sequence(output_packed, batch_first=True)
print(output.shape)

9. 常见应用场景

  • 自然语言处理(NLP)
    • 机器翻译
    • 语音识别
    • 文本分类
  • 时间序列预测
    • 股票价格预测
    • 温度/传感器数据建模
  • 信号处理
    • 语音信号建模
    • ECG 信号分类

10. 总结

特性说明
input_size每个时间步的输入维度
hidden_size隐藏层维度
num_layers多层 LSTM
batch_first输入形状控制 (batch_size, seq_len, features)
bidirectional支持双向 LSTM
proj_size支持 LSTM 投影,提高效率

torch.nn.LSTM序列建模的核心组件之一,理解其输入输出格式和多层、双向、投影等参数对实现高效的 RNN 模型至关重要。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值