各位观众老爷,大家好,我是诗人啊_,今天和各位分享RNN模型API的使用,一文速通~
(屏幕前的你,帅气低调有内涵,美丽大方很优雅…所以,求个点赞、收藏、关注呗~)
正经标题:PyTorch 的 RNN API 太坑?3 分钟拿捏维度密码,小白也能秒上手!
此文讲解torch.nn.RNN()的使用,RNN模型的执行步骤请看 ↓↓↓↓↓
书接上回——>>RNN模型: 从原理到拿下面试超绝加分!-优快云博客
目录
前言:
在处理序列数据(如文本、语音、时间序列)时,循环神经网络(RNN)是一种非常经典的模型。本文将基于 PyTorch 框架,通过实例代码详细讲解 RNN 的参数定义、输入输出维度规则以及实际使用方法,帮助初学者快速掌握 RNN 的核心要点
一、RNN 基础概念
循环神经网络(RNN)是一类含有循环连接的神经网络,其核心特点是允许信息的持久化。与传统的前馈神经网络不同,RNN 通过隐藏层之间的循环连接来记忆之前的信息,非常适合处理具有时序关系的序列数据。
RNN 的典型应用场景包括:
- 语言模型与文本生成
- 机器翻译
- 语音识别
- 时间序列预测
- 图像描述生成
二、PyTorch 中 RNN 的核心参数与维度规则
在 PyTorch 中,我们可以通过nn.RNN
类快速构建循环神经网络。下面通过实例代码,详细解析 RNN 的参数定义和维度规则。
2.1 RNN 模型定义
首先我们来看如何定义一个基础的 RNN 模型:
import torch
import torch.nn as nn
# 定义RNN模型
rnn = nn.RNN(input_size=3, hidden_size=4, num_layers=1)
这里的三个核心参数含义如下:
input_size
:输入特征的维度,需要与输入数据的特征维度匹配hidden_size
:隐藏层的特征维度,决定了隐藏状态的维度num_layers
:RNN 的层数,默认为 1,多层 RNN 可以增强模型的表达能力
2.2 初始隐藏状态 h0
RNN 需要一个初始隐藏状态来保存历史信息,定义方式如下:
# 定义初始隐藏状态h0
h0 = torch.randn(1, 2, 4) # 形状:(num_layers * num_directions, batch_size, hidden_size)
h0 的维度规则非常重要,需要严格遵守:
- 第一维:
num_layers * num_directions
(层数 × 方向数),单向单层 RNN 为 1(现在几乎都是单层),双向两层 RNN 则为 4 - 第二维:
batch_size
(批次大小),必须与输入数据的 batch_size 一致 - 第三维:
hidden_size
,必须与模型定义的 hidden_size 一致
注意:如果不手动传入 h0,PyTorch 会默认初始化为全 0 的张量
2.3 输入数据的维度规则
RNN 的输入数据需要是序列形式,其维度定义如下:
# 定义输入数据
input_info = torch.randn(3, 2, 3) # 形状:(seq_len, batch_size, input_size)
输入数据的维度规则:
- 第一维:
seq_len
(序列长度),如一句话包含 5 个单词,seq_len 就是 5 - 第二维:
batch_size
(批次大小),一次处理的样本数量 - 第三维:
input_size
,必须与模型定义的 input_size 一致
小技巧:如果设置
batch_first=True
,输入维度会变为(batch_size, seq_len,input_size)
,更符合我们的直观理解,默认情况下为False
2.4 RNN 的输出
调用 RNN 模型后,会返回两个结果:output
和hn
# 前向传播
output, hn = rnn(input_info, h0)
output
:保存了每个时间步的隐藏层输出,形状为(seq_len, batch_size, hidden_size)
(单向)hn
:仅保存最后一个时间步的隐藏状态,形状与h0
相同,即(num_layers * num_directions, batch_size, hidden_size)
三、完整代码示例与输出解析
下面我们运行完整的代码,观察输出结果的形状:
def rnn_for_base():
# 1. 定义RNN模型
rnn = nn.RNN(input_size=3, hidden_size=4, num_layers=1)
# 2. 定义初始隐藏状态h0
h0 = torch.randn(1, 2, 4) # (num_layers * num_directions, batch_size, hidden_size)
# 3. 定义输入数据
input_info = torch.randn(3, 2, 3) # (seq_len, batch_size, input_size)
# 4. 前向传播
output, hn = rnn(input_info, h0)
print('output形状--->', output.shape) # 期望输出:torch.Size([3, 2, 4])
print('hn形状--->', hn.shape) # 期望输出:torch.Size([1, 2, 4])
print('rnn模型结构--->', rnn)
rnn_for_base()
运行结果解析:
output
的形状为(3, 2, 4)
,对应(seq_len, batch_size, hidden_size)
hn
的形状为(1, 2, 4)
,对应(num_layers * num_directions, batch_size, hidden_size)
- 可以看到,
output
保存了整个序列中每个时间步的隐藏状态,而hn
只保存了最后一个时间步的隐藏状态
四、序列长度变化时的处理
RNN 的一大优势是可以处理不同长度的序列,只需保持批次大小和特征维度不变即可:
def dm_rnn_for_sequencelen():
# 1. 定义RNN模型
rnn = nn.RNN(input_size=3, hidden_size=4, num_layers=1)
# 2. 定义初始隐藏状态
h0 = torch.randn(1, 2, 4)
# 3. 定义输入数据(序列长度变为5)
input_info = torch.randn(5, 2, 3) # 仅改变了seq_len,从3变为5
# 4. 前向传播
output, hn = rnn(input_info, h0)
print('output形状--->', output.shape) # 输出:torch.Size([5, 2, 4])
print('hn形状--->', hn.shape) # 输出:torch.Size([1, 2, 4])
dm_rnn_for_sequencelen()
可以看到,当序列长度从 3 变为 5 时,output
的第一个维度也相应变为 5,而hn
的形状保持不变,因为它只关注最后一个时间步的状态。
五、batch_first 参数的使用
默认情况下,PyTorch 的 RNN 输入维度是(seq_len, batch_size, input_size)
,但我们可以通过batch_first=True
参数将其改为(batch_size, seq_len, input_size)
,这在实际应用中更为常用:
def rnn_with_batch_first():
# 定义batch_first=True的RNN模型
rnn = nn.RNN(input_size=3, hidden_size=4, num_layers=1, batch_first=True)
# 输入数据维度变为(batch_size, seq_len, input_size)
input_info = torch.randn(2, 3, 3) # batch_size=2, seq_len=3, input_size=3
# 初始隐藏状态形状不变
h0 = torch.randn(1, 2, 4)
output, hn = rnn(input_info, h0)
print('batch_first=True时output形状--->', output.shape) # 输出:torch.Size([2, 3, 4])
rnn_with_batch_first()
六、完整案例
import torch
import torch.nn as nn
'''
RNN(Recurrent Neural Network):
RNN定义: 循环神经网络,是含有循环的网络,允许信息的持久化。
RNN原理: 通过神经网络模块来存储信息。循环神经网络通过隐藏层之间的循环连接来记忆之前的信息。
RNN应用: 语言模型与序列生成、机器翻译、语音识别、生成图像描述等。
'''
def rnn_for_base():
# todo: 1. Rnn的输入参数:维度规则: (input_size, hidden_size, num_layers)
rnn = nn.RNN(input_size=3, hidden_size=4, num_layers=1)
'''
input_size: 维度关联:无直接维度规则,需与输入数据的特征维度匹配
hidden_size: 维度关联:会体现在隐藏状态(h0、hn )的第三维,以及 output 的第三维(单向时直接等于 hidden_size ,双向时翻倍)
num_layers: 维度关联:会体现在隐藏状态(h0、hn )的第一维(num_layers * num_directions )
'''
# todo: 2. h0(初始隐藏状态), 维度规则:(num_layers * num_directions, batch_size, hidden_size)
h0 = torch.randn(1, 2, 4) # 作用:初始化 RNN 各层、各方向的隐藏状态,若不手动传入,PyTorch 会默认初始化为全 0 张量
'''
第一维:num_layers * num_directions(层数 × 方向数,如 2 层双向 RNN 就是 4)。
第二维 batch_size:必须与 input_info 的 batch_size 完全一致(同批次样本)。
第三维 hidden_size:必须与模型定义的 hidden_size 一致(隐藏层特征维度)。
'''
# todo: 3. input_info(输入数据的维度表示) , 维度规则:(seq_len, batch_size, input_size)
input_info = torch.randn(3, 2, 3)
'''
第一维 seq_len:序列长度(如一句话有 10 个词,就是 10)
第二维 batch_size:批次大小(如一次输入 32 句话)
第三维 input_size:必须与模型定义的 input_size 一致(输入特征维度)
若设置 batch_first=True, 则为---> (batch_size, seq_len, input_size), 默认 batch_first=False)
'''
# todo: 4. 使用rnn返回两个值,output和hn
# output-->保存每个时间步(单词)输出的隐藏层结果, hn只是代表最后一个单词的输出结果
output, hn = rnn(input_info, h0)
print('output--->', output.shape, output)
print('hn--->', hn.shape, hn)
print('rnn模型--->', rnn)
# 输入数据长度发生变化
def dm_rnn_for_sequencelen():
# 1. 定义RNN模型 -- > nn.RNN todo: 维度规则: (input_size, hidden_size, num_layers)
rnn = nn.RNN(input_size=3, hidden_size=4, num_layers=1)
# 2. 定义初始隐藏状态--> h0 todo: 维度规则: (num_layers * num_directions, batch_size, hidden_size)
h0 = torch.randn(1, 2, 4)
# 3. 定义输入数据--> input_info todo: 维度规则: (seq_len, batch_size, input_size)
input_info = torch.randn(3, 2, 3)
# 4. 使用rnn返回两个值,output和hn
output, hn = rnn(input_info, h0)
print('output--->', output.shape)
print('hn--->', hn.shape)
print('rnn模型--->', rnn)
if __name__ == '__main__':
rnn_for_base()
print('-' * 100)
dm_rnn_for_sequencelen()
输出结果:
七、总结
本文通过实例代码详细讲解了 PyTorch 中 RNN 的使用方法,核心要点总结如下:
- RNN 的核心参数:
input_size
、hidden_size
、num_layers
- 三个关键维度的匹配规则:
- 输入数据:
(seq_len, batch_size, input_size)
或(batch_size, seq_len, input_size)
(当batch_first=True
时) - 初始隐藏状态 h0:
(num_layers * num_directions, batch_size, hidden_size)
- 输出 output:
(seq_len, batch_size, hidden_size)
或(batch_size, seq_len, hidden_size)
- 最终隐藏状态 hn:与 h0 形状相同
- 输入数据:
掌握这些维度规则是使用 RNN 的基础,在实际应用中,我们需要确保各个参数的维度相互匹配,才能正确训练和使用 RNN 模型。
希望本文能帮助大家快速理解和上手 PyTorch 中的 RNN,下一篇文章我们将介绍 LSTM 和 GRU 等更复杂的循环神经网络模型
我是诗人啊_程序员,致力于分享人工智能方面的知识,近期 NLP 自然语言处理系列文章发布中,如果感兴趣,来个关注呗~