学习笔记-PyTorch-RNN循环神经网络

1.RNN简介

        循环神经网络(RNN)是一种处理序列数据的强大工具,它能够在内部维护一个状态,从而捕捉时间序列数据中的动态特性。RNN的核心思想是利用序列中前后元素的依赖关系,通过循环连接来传递信息,使得网络能够记忆前面的信息并用于后续的计算。

应用场景:

  • 时间序列预测:如气象预报、股票价格预测等
  • 自然语言处理:用于语言模型、文本生成、机器翻译等
  • 语音识别:将语音信号转换成文本

2.单层RNN

        单层RNN网络架构如图所示。输入为序列[x1,x2,……,xn],输出为[h1,h2,……,hn],隐藏层同样也是[h1,h2,……,hn],RNN Cell权重参数均相同,封装成函数,调用伪代码为ht=cell(xt,ht-1),所以RNN核的输出既是输出也是下一个序列的输入。

维度参数说明

  • batchSize为每一批元素的数量【多少个x捆在一起作为一个batch】
  •  seqLen为输入元素的数量【x的数量】
  • inputSize为每个输入元素的维度【x的维度】
  • hiddenSize为中间隐藏层的维度【h的维度】
  • input.shape=(batchSize,inputSize)
  • output.shape=(batchSize,hiddenSize)       
  • dataset.shape=(seqLen,batchSize,inputSize)

 参数准备

import torch

batch_size=1
seq_len=3
input_size=4
hidden_size=2

调用PyTorch中的API——RNNCell

cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)

官方文档提供的关于该文档的使用说明

随机/初始化数据

dataset=torch.randn(seq_len,batch_size,input_size)
hidden=torch.zeros(batch_size,hidden_size)

循环训练

for idx,input in enumerate(dataset):
    print('='*20,'='*20)
    print('Input Size:',input.shape)
    
    hidden=cell(input,hidden)
    
    print('Output Size:',hidden.shape)
    print(hidden)

结果说明

        输出结果如下所示。注意观察维度。

完整代码

import torch

batch_size=1
seq_len=3
input_size=4
hidden_size=2

cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)

dataset=torch.randn(seq_len,batch_size,input_size)
hidden=torch.zeros(batch_size,hidden_size)

for idx,inp
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值