RNN、LSTM、GRU

  • 近年来循环神经网络在自然语言处理,语音技术,甚至图像方面都有不错的应用。本文主要介绍基础的RNN,RNN所面对的问题,以及RNN的改进版本:LSTM和GRU

RNN(Recurrent Neural Network)

RNN Architecture

  • 我们先放一张RNN的结构图,一般的RNN也遵循这个过程。输入是x1~xt,绿色的方框表示处理单元, h i h_i hi表示的是隐藏单元, y i y_i yi表示的是输出。对于不同的输入 x i , h i x_i,h_i xi,hi,RNN的cell(一个绿色框)都是彼此之间共享参数的。
  • 一般来说RNN的计算过程分成下面的步骤:
    1. 构造数据,形成{x1,x2, …, xt}的sample
    2. x i x_i xi输入给第 i i i个单元,进行计算,分别得到 y i , h i y_i, h_i yi,hi
    3. 重复上述第二步,得到 y 0 , . . . , y n y_0,...,y_n y0,...,yn,计算loss
    4. 反向传播,更新绿色框中的参数
    5. 重复1~4,直到网络收敛
  • 那么绿色框中到底是什么呢?他是怎么做到记录了上一个输入的信息呢?
  • Standard RNN Cell
    • 标准的RNN cell如下图所示,它里面其实就是封装了一层神经网络和一个非线性处理单元。
      standard RNN cell
    • 公式化如下:
      • h i = f ( W h h h i − 1 + W h x x i ) h_i = f(W^{hh}h_{i-1} + W^{hx}x_i) hi=f(Whhhi1+Whxxi),其中 f f f代表非线性激活函数,例如sigmoid(下面会以其举例说明RNN缺点)。
      • y i = s o f t m a x ( W y h i ) y_i = softmax(W^{y}h_i) yi=softmax(Wyhi),其中y是输出。
    • 它是怎么记下过去的信息的呢?是通过隐藏状态 h i h_i hi记下的。我的理解是是因为我们通过BP优化的是它,所以赋予了 h i h_i hi这么个意义,至于怎么证明 h i h_i hi就是过去的信息,还有待探索。
    • 缺点:如果输入sample里面时刻太长的话,可能会导致梯度消失,从而忘记很早时刻的信息。
      • 为了从数学的角度说明上面那一点,我们就先从BP推导起来。
      • 假设 E E E表示损失函数,令 s = W y h , y i = s o f t m a x ( s i ) s=W^{y}h, y_i=softmax(s_i) s=Wyh,yi=softmax(si)
      • ∂ E ∂ W h h = ∑ i = 1 k ∂ E ∂ y ∗ ∂ y ∂ s ∗ ∂ s ∂ h i ∗ ∂ h i ∂ W h h \frac{\partial E}{\partial W^{hh}}=\sum_{i=1}^k{\frac{\partial E}{\partial y} * \frac{\partial y}{\partial s} * \frac{\partial s}{\partial h_i} * \frac{\partial h_i}{\partial W^{hh}}} WhhE=i=1kyEsyhisWhhhi
      • 其中 i i i表示的第i时刻, k k k表示的是一共有 k k k个时刻。
      • 我们知道,在计算第 i i i时刻的梯度的时候,它与 i + 1 − > k i+1->k i+1>k时刻都有关系。并且这种关系表现在梯度上是惩罚的关系。所以我们可以得到下面的等式
      • ∂ s ∂ h i = Π j = i + 1 k ∂ h j ∂ h j − 1 = Π j = i + 1 k f ′ ( h j ) \frac{\partial s}{\partial h_i} = \Pi_{j=i+1}^k{\frac{\partial h_j}{\partial h_{j-1}}}=\Pi_{j=i+1}^k{f'(h_j)} his=Πj=i+1khj1hj=Πj=i+1kf(hj)
      • 正如我们上面所说,f(x) = sigmoid,其导数范围在0~1之间,如果我们有多个小数相乘的话,就会导致梯度为0,从而导致梯度消失。
      • 注意,我们这里的梯度消失只是针对比较靠前的输入来说,说明其输入没有起到合适的作用(梯度为0)。但是对于靠后的输入来说梯度还是存在的。因为观察上面的公式我们就可以得到靠后的梯度j~k连乘的次数少。
      • 至此,我们说了 W h h W_{hh} Whh在long sequence的传播过程中是如何产生梯度消失问题的。注意 W y W_{y} Wy应该是不会有这个问题的。因为它一般只会更新一次(如果我们只用 y k y_k yk去计算loss的话)。同理 W h x W_{hx} Whx也是会存在这个问题的。
    • 如何解决梯度消失问题呢?sigmoid既然梯度为0,那么relu呢?relu可能会导致梯度爆炸问题。因为relu(x) = x,他没有限制x的取值范围。此外relu的导数是一个常数,他不会随着x的变化而变化。sigmoid通过限制输出的大小,从而限制的整个网络的幅度。那么如何结合relu的问题的?可以使用Batch Normalization, 参考这篇博文
    • 请看下面LSTM和GRU的解决方案。

LSTM (Long Short-term Memory)

  • 正如上面说的普通的RNN会导致梯度消失的问题,那么LSTM是如何解决的呢?
  • 我们先放一张LSTM的cell,如下图所示
    LSTM Cell Architecture
    • LSTM Cell里面有如下几个重要的概念(四门一态):
      • forget gate
      • input gate
      • update gate
      • output gate
      • Cell state
    • forget gate:生成一个mask,决定cell state里面哪些信息应该被遗忘,哪些信息应该被保留。forget可以看成是对cell stage的forget。
      • 其是由 h i , x i , s i g m o i d h_i, x_i, sigmoid hi,xi,sigmoid组成,如下图所示
        forget gate
      • 其中f_t就代表forget gate的输出,它表示了我们要选择性的遗忘cell state里面的某些值(对应位置的f_t为0或者是低响应区域)。
      • 从公式的角度来看: f t = W f h h i − 1 + W f x x i f_t = W_{fh}h_{i-1} + W_{fx}x_i ft=Wfhhi1+Wfxxi
    • input gate:决定新的输入中哪些信息应该被加入的cell state中。所以input可以看成是对cell state的输出。
      • 其是由 h i − 1 , x i , s i g m o i d h_{i-1}, x_i, sigmoid hi1,xi,sigmoid组成,可以看成和forget gate结构一样,但是彼此不共享参数。
      • 其结构图如下所示, C i ^ \hat{C_i} Ci^表示一个新的cell state候选值,其和 i i i_{i} ii点乘从而决定哪些信息应该被加入新的cell state中。
        input gate
      • 数学公式表示: i i = s i g m o i d ( W i h h i − 1 + W i x x i ) , C i ^ = t a n h ( W c h h i − 1 + W c x x i ) i_i=sigmoid(W_{ih}h_{i-1} + W_{ix}x_i), \hat{C_i} = tanh(W_{ch}h_{i-1} + W_{cx}x_i) ii=sigmoid(Wihhi1+Wixxi),Ci^=tanh(Wchhi1+Wcxxi)。而这里为什么使用tanh还有待探索。tanh相对于sigmoid是0均值的。
    • update gate:更新Cell state
      • 其是对f和C作点乘,得到过滤掉信息的C,再对其加上因为本次输入需要添加的信息。
      • 结构图如下所示
        update gate
      • 数学公式表示: C i = C i − 1 ∗ f i + i i ∗ C ^ i C_i = C_{i-1} * f_i + i_{i} * \hat{C}_i Ci=Ci1fi+iiC^i,前者表示删去应该遗忘的信息后保存下来的信息,后者表示应该加上去的信息。
    • output gate:生成我们的hidden state
      • 其是由h_{i-1}, x_i 和 cell state的非线性映射进行点积运算得到的。
      • 其网络结构图如下所示:
        output gate
      • 数学表示: h i = s i g m o i d ( W o h h t − 1 + W o x x i ) ∗ t a n h ( C i ) h_i = sigmoid(W_{oh}h_{t-1}+W_{ox}x_i)*tanh(C_i) hi=sigmoid(Wohht1+Woxxi)tanh(Ci)
    • 其是怎么解决在recurrent过程中出现的梯度消失问题呢?
      • 简单来说,在对 W o h , W o x W_{oh},W_{ox} Woh,Wox计算导数的过程中,我们的 W o h , W o x W_{oh}, W_{ox} Woh,Wox计算导数就会有两部分,前者是连城,后者是加分,有一个C在里面,加分从而避免了梯度消失。比如 h i = s i g m o i d ( W o h h i − 1 + W o x x i ) ∗ t a n h ( C i ) = s i g m o i d ( W o h ( s i g m o i d ( W o h h i − 2 + W o x x i − 1 ) ∗ t a n h ( C i − 1 ) ) + W o x x i ) ∗ t a n h ( C i ) h_i=sigmoid(W_{oh}h_{i-1} + W_{ox}x_i)*tanh(C_i) = sigmoid(W_{oh}{(sigmoid(W_{oh}h_{i-2} + W_{ox}x_{i-1})*tanh(C_{i-1}) )} + W_{ox}x_i)*tanh(C_i) hi=sigmoid(Wohhi1+Woxxi)tanh(Ci)=sigmoid(Woh(sigmoid(Wohhi2+Woxxi1)tanh(Ci1))+Woxxi)tanh(Ci)
      • 复杂来讲有待探索。。

GRU (Gated recurrent unite)

  • 我们上面讲了LSTM是如何的结构,接下来我们看一下GRU是怎么样的结构。
  • 相对于LSTM的cell,GRU相对能简单一些。
    • 首先GRU没有cell state的概念,它将信息一直保存在hidden state中。
    • 其次,最后GRU的输出也是由两部分组成,一部分是上一层hidden state保存下来的有用信息(第一部分),一部分是这层新的hidden hidden state应该被加入的信息(两者取并集)(第二部分)。
      gated recurrent unite
    • GRU由update gate,reset gate,current content gate,output gate四部分组成。
    • update gate:决定上一个hideen state中哪些信息应该被保留,有点像LSTM中的forget gate
      • 其结构图如下所示:
        update gate
      • 公式化: z t = W z h h t − 1 + W z x x t z_t = W_{zh}h_{t-1} + W_{zx}x_t zt=Wzhht1+Wzxxt
    • reset gate:决定上一个state 的哪些信息应该被重置。他与update gate不同的是,update gate主要是用在第一部分。而这里的reset gate主要用在生成第二部分。
      • 其网络结构图如下所示:
        reset gate
      • 其网络结构和update gate基本一致,不共享参数,拥有相同结构。
      • 数学公式表达: r t = W r h h t − 1 + W r x x t r_t = W_{rh}h_{t-1} + W_{rx}x_t rt=Wrhht1+Wrxxt
    • current content gate: 主要是生成本cell的state(注意和输出的state不同,更“隐蔽“,有点像LSTM 里面的cell state)。
      • 其结构如下所示:
        current content gate
      • 使用当前的输x_t, 和经过reset gate处理过的上一cell的state的组合得到本cell的state。
      • 公式化如下: h t ′ = t a n h ( W x + r t ∗ h t − 1 ) h'_t = tanh(Wx + r_t * h_{t-1}) ht=tanh(Wx+rtht1)
    • output gate:输出门,将update后的上一个state和本时刻的state相结合。
      • 其网路结构如下所示:
        output gate
      • 注意,我们在这里相当于重用了 z t z_t zt,使用 1 − z t 1-z_t 1zt就表示要强化update后的上一个时刻没有的信息。
      • 公式化表达: h t = z t ∗ h i + ( 1 + z t ) ∗ h i ′ h_t = z_t * h_i + (1+z_t) * h'_i ht=zthi+(1+zt)hi

对比LSTM和GRU

  • 相似点:
    • 他们相比于传统的RNN,他们都引入了新的gate。
    • 在更新memory content的时候,他们都是原有的content+新生成的content的形式。也就是说他们都会create 一个hidden的hidden new memory content,用这个content和previous content相加,得到最后的content。例如GRU: h t = z t ∗ h i + ( 1 + z t ) ∗ h i ′ h_t = z_t * h_i + (1+z_t) * h'_i ht=zthi+(1+zt)hi;LSTM: C i = C i − 1 ∗ f i + i i ∗ C ^ i C_i = C_{i-1} * f_i + i_{i} * \hat{C}_i Ci=Ci1fi+iiC^i
  • 不同点:
    • 在向下一层传递state的时候,LSTM比GRU多了一个control gate。对比起来GRU: h t = z t ∗ h i + ( 1 + z t ) ∗ h i ′ h_t = z_t * h_i + (1+z_t) * h'_i ht=zthi+(1+zt)hi,而LSTM: h i = s i g m o i d ( W o h h t − 1 + W o x x i ) ∗ t a n h ( C i ) h_i = sigmoid(W_{oh}h_{t-1}+W_{ox}x_i)*tanh(C_i) hi=sigmoid(Wohht1+Woxxi)tanh(Ci),前面的sigmoid就是多出来的control gate。体现在LSTM Cell的结构图是就如下所示:
      different1
    • 第二点不同就是在更新state的时候,针对新生成的memory content,LSTM也比GRU多了一个control gate。用来控制哪些元素应该被用来更新。体现在公式上, GRU: h t ′ = t a n h ( W x + r t ∗ h t − 1 ) h'_t = tanh(Wx + r_t * h_{t-1}) ht=tanh(Wx+rtht1),LSTM: C i = C i − 1 ∗ f i + i i ∗ C ^ i C_i = C_{i-1} * f_i + i_{i} * \hat{C}_i Ci=Ci1fi+iiC^i。体现在LSTM Cell的结构图上就如下图所示
      different 2

参考文献

  1. How RNN work
  2. Understanding LSTM
  3. Understanding GRU
  4. Different between GRU and LSTM
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值