使用Torch nngraph实现LSTM

本文介绍了循环神经网络(RNN)的基本概念及其面临的梯度消失问题,并详细解析了长短期记忆网络(LSTM)的设计原理与实现方式,展示了如何使用Torch中的nngraph库构建LSTM模块。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

转自:http://blog.youkuaiyun.com/lanran2/article/details/50603861

什么是RNN

RNN:多层反馈RNN(Recurrent neural Network、循环神经网络)神经网络是一种节点定向连接成环的人工神经网络。这种网络的内部状态可以展示动态时序行为。不同于前馈神经网络的是,RNN可以利用它内部的记忆来处理任意时序的输入序列,这让它可以更容易处理如不分段的手写识别、语音识别等。——百度百科

下面我们看看抽象出来的RNN的公式:
ht=θϕ(ht1)+θxxt
yt=θyϕ(ht)
可以发现每次RNN都要使用上一次中间层的输出 ht

传统RNN的缺点—梯度消失问题(Vanishing gradient problem)

我们定义loss function为 E ,那么梯度公式如下:
Eθ=St=1Etθ
Etθ=tk=1Etytythththkhkθ
hthk=ti=k+1hihi1=ti=k+1θTdiag[ϕ(hi1)]
||hihi1||||θT||||diag[ϕ(hi1)]||γθγϕ
hthk(γθγϕ)tk
不断的乘以小于1的数,梯度会越来越小。为了解决这个问题,LSTM应运而生。

LSTM介绍

定义:LSTM(Long-Short Term Memory,LSTM)
是一种时间递归神经网络,论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。——百度百科

提到LSTM,总会伴随一张如下所示的图:
这里写图片描述
从图中可以看到除了输入,还有三个部分:1)Input Gate;2)Forget Gate;3)Output Gate
根据上文提到的RNN,我们的输入是 xt ht1 ,而输入是 ht ct (cell state),其中cell state是LSTM的关键所在,它使得LSTM具有记忆功能。下面是关于LSTM的公式:
1)Input Gate:
it=σ(Wxixt+Whiht1+bi)=σ(linearxi(xt)+linearhi(ht1))
其中 σ 指的是sigmoid函数。
2)Forget Gate:决定是否删除或者保留记忆(memory)
ft=σ(Wxfxt+Whfht1+bf)
3)Output Gate:
ot=σ(Wxoxt+Whoht1+bo)
4)Cell update:
gt=tanh(Wxgxt+Whght1+bg)
5)Cell State Update:
ct=ftct1+itgt
6)Final Output of LSTM:
ht=ottanh(ct)
以上是LSTM一个cell涉及的公式,下面讲解为什么LSTM可以解决在RNN中梯度消失的问题
对于任意 k<t ,我们可以计算:
这里写图片描述
由于每一个因子都非常接近1,所以梯度很难衰减,这样就解决了梯度消失的问题。

Torch nngraph

在使用torch编写LSTM之前,我们需要学习torch中的一个工具nngraph,安转nngraph的命令如下:

<code class="hljs cmake has-numbering">luarocks <span class="hljs-keyword">install</span> nngraph</code>

nngraph的详细介绍:https://github.com/torch/nngraph/
nngraph可以方便设计一个神经网络模块。我们先使用nngraph创建一个简单的网络模块:


z=x1+x2linear(x3)
可以看出这个模块的输入一共有三个, x1 x2 x3 ,输出是 z 。下面是实现这个模块的torch代码:

require 'nngraph'
x1=nn.Identity()()
x2=nn.Identity()()
x3=nn.Identity()()
L=nn.CAddTable()({x1,nn.CMulTable()({x2,nn.Linear(20,10)(x3)})})
mlp=nn.gModule({x1,x2,x3},{L})


首先我们定义 x1 x2 x3 ,使用nn.Identity()();然后对于 linear(x3) ,我们使用x4=nn.Linear(20,10)(x3),定义了一个输入层有20个神经元,输出层有10个神经元的线性神经网络;对于 x2linear(x3) ,使用x5=nn.CMulTable()(x2,x4);对于 x1+x2linear(x3) ,我们使用nn.CAddTable()(x1,x5) 实现;最后使用nn.gModule({input},{output})来定义神经网络模块。
我们使用forward方法测试我们的Module是否正确:

h1=torch.Tensor{1,2,3,4,5,6,7,8,9,10}
h2=torch.Tensor(10):fill(1)
h3=torch.Tensor(20):fill(2)
b=mlp:forward({h1,h2,h3})
parameters=mlp:parameters()[1]
bias=mlp:parameters()[2]
result=torch.cmul(h2,(parameters*h3+bias))+h1


首先我们定义三个输入 h1 , h2 h3 ,然后调用模块mpl的forward命令得到输出b,然后我们获取网络权重w和bias分别保存在parameters和bias变量中,计算 z=h1+h2linear(h3) 的结果result=torch.cmul(h2,(parameters*h3+bias))+h1,最后比较result和b是否一致,我们发现计算的结果是一样的,说明我们的模块是正确的。

使用nngraph编写LSTM模块

现在我们使用nngraph写前文所描述的LSTM模块,代码如下:

require 'nngraph'
function lstm(xt,prev_c,prev_h)
function new_input_sum()
local i2h=nn.Linear(400,400)
local h2h=nn.Linear(400,400)
return nn.CAddTable()({i2h(xt),h2h(prev_h)})
end
local input_gate=nn.Sigmoid()(new_input_sum())
local forget_gate=nn.Sigmoid()(new_input_sum())
local output_gate=nn.Sigmoid()(new_input_sum())
local gt=nn.Tanh()(new_input_sum())
local ct=nn.CAddTable()({nn.CMulTable()({forget_gate,prev_c}),nn.CMulTable()({input_gate,gt})})
local ht=nn.CMulTable()({output_gate,nn.Tanh()(ct)})
return ct,ht
end
xt=nn.Identity()()
prev_c=nn.Identity()()
prev_h=nn.Identity()()
lstm=nn.gModule({xt,prev_c,prev_h},{lstm(xt,prev_c,prev_h)})

其中xtprev_h是输入,prev_c是cell state,然后我们按照前文的公式一次计算,最后输出ct(new cell state),ht(输出)。代码的计算顺序与上文完全一致,所以这里就不再一一解释了。


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值