RNN
一、综述
当我们考虑一个问题的时候,并不是从零开始的。例如我们做阅读理解题时,我们分析问题是基于我们对之前阅读内容的一个理解,而不是忘掉前面阅读的内容,从新开始思考。
RNN(Recurrent Neural Network)的原理就是将信息持久化,通过对之前内容的理解,做出判断。
例如:通过之前用户对某家商店的评论情况学习,判断当前用户对该商店的评论属于好评还是差评。
二、 RNN结构
一个简单的循环神经网络由:输入层、隐藏层和输出层组成。
- 网络在t时刻接收到输入 x t x_t xt之后,隐藏层的值是 s t s_t st,输出值是 o t o_t ot。
-
s
t
s_t
st 的值不仅仅取决于
x
t
x_t
xt,还取决于
s
t
−
1
s_{t-1}
st−1。
公式:
s t = t a n h ( U ∗ x t + W ∗ s t − 1 ) s_t = tanh(U*x_t+W*s_{t-1}) st=tanh(U∗xt+W∗st−1)
o t = s o f t m a x ( V ∗ s t ) o_t = softmax(V*s_t) ot=softmax(V∗st)
标准RNN的有以下特点:
- 权值共享,图中的W全是相同的,U和V也一样。
- 每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。
三、RNN的缺陷(梯度消失与梯度爆炸)
如上图所示,为经典的RNN结构。
RNN前向传导公式为:
s
t
=
g
(
U
x
t
+
W
s
t
−
1
)
s_t = g(Ux_t+Ws_{t-1})
st=g(Uxt+Wst−1)
o
t
=
f
(
V
s
t
)
o_t = f(Vs_t)
ot=f(Vst)
假设我们的时间序列只有三段,
s
0
s_0
s0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
当t=1时:
状态: s 1 = g ( U x 1 + W s 0 ) s_1=g(Ux_1+Ws_0) s1=g(Ux1+Ws0)
输出: o 1 = f ( V s 1 ) o_1=f(Vs_1) o1=f(Vs1)
当t=2时:
状态: s 2 = g ( U x 2 + W s 1 ) s_2 = g(Ux_2+Ws_1) s2=g(Ux2+Ws1)
输出: o 2 = f ( V s 2 ) = f ( V g ( U x 2 + W s 1 ) ) = f ( V g ( U x 2 + W g ( U x 1 + W s 0 ) ) ) o_2 = f(Vs_2)=f(Vg(Ux_2+Ws_1))=f(Vg(Ux_2+Wg(Ux_1+Ws_0))) o2=f(Vs2)=f(Vg(Ux2+Ws1))=f(Vg(Ux2+Wg(Ux1+Ws0)))
当t=3时:
状态: s 3 = g ( U x 3 + W s 2 ) s_3=g(Ux_3+Ws_2) s3=g(Ux3+Ws2)
输出: o 3 = f ( V s 3 ) = f ( V g ( U x 3 + W g ( U x 2 + W g ( U x 1 + W s 0 ) ) ) ) o_3=f(Vs_3)=f(Vg(Ux_3+Wg(Ux_2+Wg(Ux_1+Ws_0)))) o3=f(Vs3)=f(Vg(Ux3+Wg(Ux2+Wg(Ux1+Ws0))))
假设在t=3时刻,损失函数为
L
3
=
1
2
(
y
3
−
o
3
)
2
L_3=\frac{1}{2}(y_3-o_3)^2
L3=21(y3−o3)2
我们使用随机梯度下降法训练RNN其实就是对U、V、W求偏导,并不断调整它们以使L尽可能达到最小的过程。
我们只对t3时刻的U、V、W求偏导(其他时刻类似)
∂
L
3
V
=
∂
L
3
∂
o
3
∂
o
3
∂
V
\frac{\partial L_3}{V}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial V}
V∂L3=∂o3∂L3∂V∂o3
∂
L
3
U
=
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
U
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
U
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
s
1
∂
s
1
∂
U
=
∑
k
=
1
3
∂
L
3
∂
o
3
∂
o
3
∂
s
3
(
∏
j
=
k
+
1
3
∂
s
k
∂
s
k
−
1
)
∂
s
k
∂
U
\frac{\partial L_3}{U}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}(\prod_{j=k+1}^3\frac{\partial s_k}{\partial s_{k-1}})\frac{\partial s_k}{\partial U}
U∂L3=∂o3∂L3∂s3∂o3∂U∂s3+∂o3∂L3∂s3∂o3∂s2∂s3∂U∂s2+∂o3∂L3∂s3∂o3∂s2∂s3∂s1∂s2∂U∂s1=k=1∑3∂o3∂L3∂s3∂o3(j=k+1∏3∂sk−1∂sk)∂U∂sk
∂
L
3
W
=
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
W
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
W
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
s
1
∂
s
1
∂
W
=
∑
k
=
1
3
∂
L
3
∂
o
3
∂
o
3
∂
s
3
(
∏
j
=
k
+
1
3
∂
s
k
∂
s
k
−
1
)
∂
s
k
∂
W
\frac{\partial L_3}{W}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial W}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}(\prod_{j=k+1}^3\frac{\partial s_k}{\partial s_{k-1}})\frac{\partial s_k}{\partial W}
W∂L3=∂o3∂L3∂s3∂o3∂W∂s3+∂o3∂L3∂s3∂o3∂s2∂s3∂W∂s2+∂o3∂L3∂s3∂o3∂s2∂s3∂s1∂s2∂W∂s1=k=1∑3∂o3∂L3∂s3∂o3(j=k+1∏3∂sk−1∂sk)∂W∂sk
可以看出对于V求偏导并没有长期依赖,但是对于U、W求偏导,会随着时间序列产生长期依赖。
根据上式推到任意时刻对W求偏导的公式为:
∂
L
t
∂
W
=
∑
t
=
1
t
∂
L
t
∂
o
t
∂
o
t
∂
s
t
(
∏
j
=
t
+
1
t
∂
s
t
∂
s
t
−
1
)
∂
s
t
∂
W
\frac{\partial L_t}{\partial W}=\sum_{t=1}^t \frac{\partial L_t}{\partial o_t}\frac{\partial o_t}{\partial s_t}(\prod_{j=t+1}^t \frac{\partial s_t}{\partial s_{t-1}}) \frac{\partial s_t}{\partial W}
∂W∂Lt=t=1∑t∂ot∂Lt∂st∂ot(j=t+1∏t∂st−1∂st)∂W∂st
s
t
=
g
(
U
x
t
+
W
s
t
−
1
)
=
t
a
n
h
(
U
x
t
+
W
s
t
−
1
)
s_t=g(Ux_t +W s_{t-1})=tanh(Ux_t+Ws_{t-1})
st=g(Uxt+Wst−1)=tanh(Uxt+Wst−1)
所以
∏
j
=
t
+
1
t
∂
s
t
∂
s
t
−
1
=
∏
j
=
t
−
1
t
t
a
n
h
′
W
\prod_{j=t+1}^t\frac{\partial s_t}{\partial s_{t-1}}=\prod_{j=t-1}^t tanh'W
∏j=t+1t∂st−1∂st=∏j=t−1ttanh′W
由上图可以看出
t
a
n
h
′
<
=
1
tanh'<=1
tanh′<=1,对于训练过程大部分情况下tanh的导数是小于1的,若W也是一个大于0小于1的值,则当t很大时,
∏
j
=
t
−
1
t
t
a
n
h
′
W
\prod_{j=t-1}^t tanh'W
∏j=t−1ttanh′W就会趋近于0;若W很大,则
∏
j
=
t
−
1
t
t
a
n
h
′
W
\prod_{j=t-1}^t tanh'W
∏j=t−1ttanh′W就会趋近于无穷大,这就是RNN中梯度消失和梯度爆炸的原因。