1.RNN简介
循环神经网络(RNN)是一种处理序列数据的强大工具,它能够在内部维护一个状态,从而捕捉时间序列数据中的动态特性。RNN的核心思想是利用序列中前后元素的依赖关系,通过循环连接来传递信息,使得网络能够记忆前面的信息并用于后续的计算。
应用场景:
- 时间序列预测:如气象预报、股票价格预测等
- 自然语言处理:用于语言模型、文本生成、机器翻译等
- 语音识别:将语音信号转换成文本
2.单层RNN

单层RNN网络架构如图所示。输入为序列[x1,x2,……,xn],输出为[h1,h2,……,hn],隐藏层同样也是[h1,h2,……,hn],RNN Cell权重参数均相同,封装成函数,调用伪代码为ht=cell(xt,ht-1),所以RNN核的输出既是输出也是下一个序列的输入。

维度参数说明
- batchSize为每一批元素的数量【多少个x捆在一起作为一个batch】
- seqLen为输入元素的数量【x的数量】
- inputSize为每个输入元素的维度【x的维度】
- hiddenSize为中间隐藏层的维度【h的维度】
- input.shape=(batchSize,inputSize)
- output.shape=(batchSize,hiddenSize)
- dataset.shape=(seqLen,batchSize,inputSize)
参数准备
import torch
batch_size=1
seq_len=3
input_size=4
hidden_size=2
调用PyTorch中的API——RNNCell
cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)
官方文档提供的关于该文档的使用说明

随机/初始化数据
dataset=torch.randn(seq_len,batch_size,input_size)
hidden=torch.zeros(batch_size,hidden_size)
循环训练
for idx,input in enumerate(dataset):
print('='*20,'='*20)
print('Input Size:',input.shape)
hidden=cell(input,hidden)
print('Output Size:',hidden.shape)
print(hidden)
结果说明
输出结果如下所示。注意观察维度。

完整代码
import torch
batch_size=1
seq_len=3
input_size=4
hidden_size=2
cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)
dataset=torch.randn(seq_len,batch_size,input_size)
hidden=torch.zeros(batch_size,hidden_size)
for idx,inp

最低0.47元/天 解锁文章
1636

被折叠的 条评论
为什么被折叠?



