RNN

RNN

一、综述

当我们考虑一个问题的时候,并不是从零开始的。例如我们做阅读理解题时,我们分析问题是基于我们对之前阅读内容的一个理解,而不是忘掉前面阅读的内容,从新开始思考。
RNN(Recurrent Neural Network)的原理就是将信息持久化,通过对之前内容的理解,做出判断。
例如:通过之前用户对某家商店的评论情况学习,判断当前用户对该商店的评论属于好评还是差评。

二、 RNN结构

一个简单的循环神经网络由:输入层、隐藏层和输出层组成。

  • 网络在t时刻接收到输入 x t x_t xt之后,隐藏层的值是 s t s_t st,输出值是 o t o_t ot
  • s t s_t st 的值不仅仅取决于 x t x_t xt,还取决于 s t − 1 s_{t-1} st1
    rnn结构
    公式:
    s t = t a n h ( U ∗ x t + W ∗ s t − 1 ) s_t = tanh(U*x_t+W*s_{t-1}) st=tanh(Uxt+Wst1)
    o t = s o f t m a x ( V ∗ s t ) o_t = softmax(V*s_t) ot=softmax(Vst)

标准RNN的有以下特点:

  • 权值共享,图中的W全是相同的,U和V也一样。
  • 每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。

三、RNN的缺陷(梯度消失与梯度爆炸)

经典的rnn结构
如上图所示,为经典的RNN结构。
RNN前向传导公式为:
s t = g ( U x t + W s t − 1 ) s_t = g(Ux_t+Ws_{t-1}) st=g(Uxt+Wst1)
o t = f ( V s t ) o_t = f(Vs_t) ot=f(Vst)
假设我们的时间序列只有三段, s 0 s_0 s0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:

当t=1时:
状态: s 1 = g ( U x 1 + W s 0 ) s_1=g(Ux_1+Ws_0) s1=g(Ux1+Ws0)
输出: o 1 = f ( V s 1 ) o_1=f(Vs_1) o1=f(Vs1)
当t=2时:
状态: s 2 = g ( U x 2 + W s 1 ) s_2 = g(Ux_2+Ws_1) s2=g(Ux2+Ws1)
输出: o 2 = f ( V s 2 ) = f ( V g ( U x 2 + W s 1 ) ) = f ( V g ( U x 2 + W g ( U x 1 + W s 0 ) ) ) o_2 = f(Vs_2)=f(Vg(Ux_2+Ws_1))=f(Vg(Ux_2+Wg(Ux_1+Ws_0))) o2=f(Vs2)=f(Vg(Ux2+Ws1))=f(Vg(Ux2+Wg(Ux1+Ws0)))
当t=3时:
状态: s 3 = g ( U x 3 + W s 2 ) s_3=g(Ux_3+Ws_2) s3=g(Ux3+Ws2)
输出: o 3 = f ( V s 3 ) = f ( V g ( U x 3 + W g ( U x 2 + W g ( U x 1 + W s 0 ) ) ) ) o_3=f(Vs_3)=f(Vg(Ux_3+Wg(Ux_2+Wg(Ux_1+Ws_0)))) o3=f(Vs3)=f(Vg(Ux3+Wg(Ux2+Wg(Ux1+Ws0))))

假设在t=3时刻,损失函数为 L 3 = 1 2 ( y 3 − o 3 ) 2 L_3=\frac{1}{2}(y_3-o_3)^2 L3=21(y3o3)2
我们使用随机梯度下降法训练RNN其实就是对U、V、W求偏导,并不断调整它们以使L尽可能达到最小的过程。
我们只对t3时刻的U、V、W求偏导(其他时刻类似)
∂ L 3 V = ∂ L 3 ∂ o 3 ∂ o 3 ∂ V \frac{\partial L_3}{V}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial V} VL3=o3L3Vo3
∂ L 3 U = ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ U + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ U + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ U = ∑ k = 1 3 ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ( ∏ j = k + 1 3 ∂ s k ∂ s k − 1 ) ∂ s k ∂ U \frac{\partial L_3}{U}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}(\prod_{j=k+1}^3\frac{\partial s_k}{\partial s_{k-1}})\frac{\partial s_k}{\partial U} UL3=o3L3s3o3Us3+o3L3s3o3s2s3Us2+o3L3s3o3s2s3s1s2Us1=k=13o3L3s3o3(j=k+13sk1sk)Usk
∂ L 3 W = ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ W + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ W + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ W = ∑ k = 1 3 ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ( ∏ j = k + 1 3 ∂ s k ∂ s k − 1 ) ∂ s k ∂ W \frac{\partial L_3}{W}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial W}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}(\prod_{j=k+1}^3\frac{\partial s_k}{\partial s_{k-1}})\frac{\partial s_k}{\partial W} WL3=o3L3s3o3Ws3+o3L3s3o3s2s3Ws2+o3L3s3o3s2s3s1s2Ws1=k=13o3L3s3o3(j=k+13sk1sk)Wsk
可以看出对于V求偏导并没有长期依赖,但是对于U、W求偏导,会随着时间序列产生长期依赖。
根据上式推到任意时刻对W求偏导的公式为:
∂ L t ∂ W = ∑ t = 1 t ∂ L t ∂ o t ∂ o t ∂ s t ( ∏ j = t + 1 t ∂ s t ∂ s t − 1 ) ∂ s t ∂ W \frac{\partial L_t}{\partial W}=\sum_{t=1}^t \frac{\partial L_t}{\partial o_t}\frac{\partial o_t}{\partial s_t}(\prod_{j=t+1}^t \frac{\partial s_t}{\partial s_{t-1}}) \frac{\partial s_t}{\partial W} WLt=t=1totLtstot(j=t+1tst1st)Wst

s t = g ( U x t + W s t − 1 ) = t a n h ( U x t + W s t − 1 ) s_t=g(Ux_t +W s_{t-1})=tanh(Ux_t+Ws_{t-1}) st=g(Uxt+Wst1)=tanh(Uxt+Wst1)
所以
∏ j = t + 1 t ∂ s t ∂ s t − 1 = ∏ j = t − 1 t t a n h ′ W \prod_{j=t+1}^t\frac{\partial s_t}{\partial s_{t-1}}=\prod_{j=t-1}^t tanh'W j=t+1tst1st=j=t1ttanhW
tanh函数
由上图可以看出 t a n h ′ < = 1 tanh'<=1 tanh<=1,对于训练过程大部分情况下tanh的导数是小于1的,若W也是一个大于0小于1的值,则当t很大时, ∏ j = t − 1 t t a n h ′ W \prod_{j=t-1}^t tanh'W j=t1ttanhW就会趋近于0;若W很大,则 ∏ j = t − 1 t t a n h ′ W \prod_{j=t-1}^t tanh'W j=t1ttanhW就会趋近于无穷大,这就是RNN中梯度消失和梯度爆炸的原因。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值