Pytorch学习笔记之LSTM

本文深入探讨了长短期记忆(LSTM)网络的工作原理及其在PyTorch中的应用。通过对比传统神经网络,阐述了RNN如何解决信息持久化的问题,并详细分析了LSTM的内部计算过程,包括输入、遗忘、细胞和输出门的运作机制。

Pytorch学习笔记之LSTM


看了理解LSTM这篇博文,在这里写写自己对LSTM网络的一些认识!。

  • RNN
  • 网络计算过程

Recurrent Neural Networks

人类并不是每时每刻都从一片空白的大脑开始他们的思考。在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义。我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考。我们的思想拥有持久性。
传统的神经网络并不能做到这点,看起来也像是一种巨大的弊端。例如,假设你希望对电影中的每个时间点的时间类型进行分类。传统的神经网络应该很难来处理这个问题——使用电影中先前的事件推断后续的事件。
RNN 解决了这个问题。RNN 是包含循环的网络,允许信息的持久化

在这里插入图片描述
这是一个经典的RNN的流程图。


1. LSTM网络

经典的LSTM的流程图:

在这里插入图片描述

相信大家都看过这个图(盗用别人的图)。
再来一段公式,就是下面的,公式来自Pytorch。

h t h_t ht is the hidden state at time t t t , c t c_t ct is the cell state at time t t t , x t x_t xt is the input at time t t t, h ( t − 1 ) h_{(t-1)} h(t1) is the hidden state of the previous layer at time t − 1 t-1 t1 or the initial hidden state at time 0 0 0 , and i t i_t it , f t f_t ft , g t g_t gt , o t o_t ot are the input, forget, cell, and output gates, respectively. σ \sigma σ is the sigmoid function.

2. 内部计算分析

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

在这里插入图片描述

可以看到参数的大小变成了(4*20,10),是标准RNN的四倍。原因是这里它包括了四个参数矩阵 W i i W_{ii} Wii W i f W_{if} Wif W i g W_{ig} Wig W i o W_{io} Wio,它们每一个都是(20×10),输入的维度大小是(10×1), 这样 i t i_t it , f t f_t ft , g t g_t gt , o t o_t ot 的维度都是(20×1),公式(5)(6)的运算应该是叉积(元素积),这样得到的 c t c_t ct h t h_t ht的维度才能是20。

在这里插入图片描述
如上图所示hn和cn的最后一维都是20。注意这里的LSTM网络是单向,双向的要*2。蟹蟹!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值