假设已具备循环神经网络基础知识
1 tensorflow文档链接
tensorflow手册链接:https://github.com/tensorflow/docs/tree/r1.13/site/en/api_docs/python/tf
2 RNN类单元
RNN是针对序列数据建模的一类神经网络模型框架。tensorflow提供了封装好的接口可以直接传入参数调用即可。
RNN包括简单RNN、LSTM、GRU等单元,调用方法类似。在需要对数据(变长序列)进行双向信息提取的时候可以调用bidirectional_dynamic_rnn。
3 简单示例
3.1 数据
假设有序列数据['汤姆','正在','追赶','杰瑞'],['杰瑞','躲进','了','草丛']
,对这个序列数据进行信心建模:
首先,去停用词:将了
从文本数据中取出,得到预处理数据:['汤姆','正在','追赶','杰瑞'],['杰瑞','躲进','草丛']
其次,进行word_embedding,将文字转换成计算机可以识别(计算)的数字。每个词(词组)转换成一个向量(假设词向量维度为3):
汤姆:[ 1.86343176, -1.59499821, -0.46139333]
正在:[-0.01371737, -0.77550861, 0.10185807]
追赶:[ 0.9538792 , -1.08308753, -1.42873875]
杰瑞:[-0.90520407, -0.30969364, 0.41886028]
杰瑞:[-0.90520407, -0.30969364, 0.41886028]
躲进:[-0.76595131, 0.90407954, -1.00486262]
草丛:[-1.36344792, 0.97559982, -0.28362745]
两句话是两个序列,序列长度分别为4和3,需要对第二个序列进行填充(0填充),保证两个序列长度一致,这样作为模型输入数据的时候不会出错。
得到的输入数据即为:
input_data = array([[[ 1.86343176, -1.59499821, -0.46139333],
[-0.01371737, -0.77550861, 0.10185807],
[ 0.9538792 , -1.08308753, -1.42873875],
[-0.90520407, -0.30969364, 0.41886028]],
[[-0.90520407, -0.30969364, 0.41886028],
[-0.76595131, 0.90407954, -1.00486262],
[-1.36344792, 0.97559982, -0.28362745],
[ 0. , 0. , 0. ]]])
3.2 构建RNN计算单元
seq_length = [4,3]
import tensorflow as tf
# 构建cell_fw和cell_bw
cell_fw = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
cell_bw = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
# 初始化cell,不设置初始化则默认为0初始化
initial_state_fw = cell_fw.zero_state(2,dtype=tf.float64)
initial_state_bw = cell_fw.zero_state(2,dtype=tf.float64)
outputs,states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw,cell_bw=cell_bw,inputs=input_data,initial_state_fw=initial_state_fw,initial_state_bw=initial_state_bw,sequence_length=seq_length,dtype=tf.float64)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
o1,s1=sess.run([outputs,states])
print(o1)
print(s1)
得到o1和s1结果如下:
# o1
(array([[[ 0.16748314, -0.26319363, 0.04769821, -0.05121024],
[ 0.13402652, -0.29065524, 0.05903449, -0.03146086],
[ 0.25099752, -0.24584206, 0.01866094, -0.04430275],
[ 0.06783996, -0.21646949, -0.00682754, 0.01544742]],
[[-0.04379599, 0.03874637, -0.05141118, 0.05217115],
[-0.13575499, 0.09677944, -0.20853611, 0.07095939],
[-0.20246804, 0.18730733, -0.27951439, 0.11607156],
[ 0. , 0. , 0. , 0. ]]]),
array([[[ 2.92630630e-01, 9.11851631e-02, -1.00478860e-01,
-4.01103255e-02],
[ 6.06710286e-02, -4.59981474e-02, -1.89261316e-01,
-1.42948033e-04],
[-1.00241262e-02, -1.24285009e-02, -1.34683944e-01,
2.99961826e-02],
[-4.42878719e-02, -8.51187469e-02, -4.84867945e-02,
9.52617845e-03]],
[[-2.16223927e-01, -1.83200812e-01, -2.72675180e-01,
1.24081684e-01],
[-1.68285764e-01, -1.47022809e-01, -1.67582119e-01,
2.29620903e-01],
[-1.23245635e-01, -9.67814977e-02, -1.04795337e-01,
1.40120730e-01],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]]]))
# s1
(LSTMStateTuple(c=array([[ 0.14960228, -0.36727263, -0.0156995 , 0.03011006],
[-0.58717123, 0.39381341, -0.57109581, 0.16564561]]), h=array([[ 0.06783996, -0.21646949, -0.00682754, 0.01544742],
[-0.20246804, 0.18730733, -0.27951439, 0.11607156]])),
LSTMStateTuple(c=array([[ 0.5831851 , 0.28922052, -0.1995207 , -0.14177551],
[-0.46907702, -0.38997103, -0.49948155, 0.23605249]]), h=array([[ 0.29263063, 0.09118516, -0.10047886, -0.04011033],
[-0.21622393, -0.18320081, -0.27267518, 0.12408168]])))
o1
是每个时间步的hidden_layer输出,形式为tuple
,每个元素对应正向
和反向
结果。o1
中正向和反向结果中又各对应每个输入序列的hidden_layer的输出值(序列1长度为4;序列2长度为3)
s1
是前向和后向每个序列最终时间步
的cell状态
和hidden_layer输出
。
对比可以发现o1
的每个序列最后时间步的值和s1
的h值相同。
initial_state_fw和initial_state_bw的可以用其他的RNN模型输出作为cell的初始化。
如使用上面的s1作为初始化,得到的结果将和0初始化不同:
initial_state_fw = s1[0]
initial_state_bw = s1[1]
sess.run(tf.global_variables_initializer())
outputs2,states2 = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw,cell_bw=cell_bw,inputs=input_data,initial_state_fw=initial_state_fw,initial_state_bw=initial_state_bw,sequence_length=seq_length,dtype=tf.float64)
o2,s2 = sess.run([outputs2,states2])
print(o2)
print(s2)
得到o2和s2结果:
# o2
(array([[[ 0.09264089, -0.08258323, 0.05526498, -0.19484421],
[ 0.12510858, -0.04264845, 0.06277278, -0.16281255],
[ 0.05611837, 0.04899326, 0.07337347, -0.25136631],
[ 0.12133026, 0.01433822, 0.02793275, -0.08281042]],
[[-0.09847787, 0.11218909, -0.17476367, 0.02773898],
[-0.13967064, 0.09041934, -0.27535409, 0.19656187],
[-0.07928222, 0.00569383, -0.24421959, 0.19889229],
[ 0. , 0. , 0. , 0. ]]]),
array([[[ 0.15853502, 0.33094248, -0.29151092, -0.21539306],
[ 0.16352575, 0.18191235, -0.23442651, -0.14024268],
[ 0.18102019, 0.16179395, -0.21143666, -0.13473829],
[ 0.16336909, 0.11958743, -0.14033709, -0.05007693]],
[[-0.09515931, -0.22884901, -0.06287023, 0.19503592],
[-0.07946009, -0.12400517, -0.04894213, 0.30135051],
[-0.16788661, -0.10867359, -0.06732264, 0.22288605],
[ 0. , 0. , 0. , 0. ]]]))
# s2
(LSTMStateTuple(c=array([[ 0.29375743, 0.0233761 , 0.07331977, -0.22109375],
[-0.20348281, 0.00997103, -0.46476339, 0.75725753]]), h=array([[ 0.12133026, 0.01433822, 0.02793275, -0.08281042],
[-0.07928222, 0.00569383, -0.24421959, 0.19889229]])),
LSTMStateTuple(c=array([[ 0.36527035, 0.51972398, -0.49281985, -0.45606273],
[-0.22130596, -0.48458956, -0.15229997, 0.49193544]]), h=array([[ 0.15853502, 0.33094248, -0.29151092, -0.21539306],
[-0.09515931, -0.22884901, -0.06287023, 0.19503592]])))
4 输出结果
o和s得到的结果:o中的每个时间步的hidden_layer输出既可以认为是包含了到目前为止前面每个时间步元素信息的结果;最后的时间步的hidden_layer输出可以认为是包含了整个序列的信息的编码。
对于前向和后向传播得到的hidden_layer之还可以进行拼接,用作下游任务的输入。
比如,对o1的结果进行拼接:
h_concat = tf.concat(o1,2)
print(h_concat)
# 得到拼接结果,从维度3变成维度6
array([[[ 1.67483141e-01, -2.63193627e-01, 4.76982142e-02,
-5.12102408e-02, 2.92630630e-01, 9.11851631e-02,
-1.00478860e-01, -4.01103255e-02],
[ 1.34026516e-01, -2.90655240e-01, 5.90344865e-02,
-3.14608636e-02, 6.06710286e-02, -4.59981474e-02,
-1.89261316e-01, -1.42948033e-04],
[ 2.50997517e-01, -2.45842063e-01, 1.86609356e-02,
-4.43027535e-02, -1.00241262e-02, -1.24285009e-02,
-1.34683944e-01, 2.99961826e-02],
[ 6.78399579e-02, -2.16469487e-01, -6.82754071e-03,
1.54474225e-02, -4.42878719e-02, -8.51187469e-02,
-4.84867945e-02, 9.52617845e-03]],
[[-4.37959874e-02, 3.87463715e-02, -5.14111802e-02,
5.21711487e-02, -2.16223927e-01, -1.83200812e-01,
-2.72675180e-01, 1.24081684e-01],
[-1.35754985e-01, 9.67794414e-02, -2.08536112e-01,
7.09593910e-02, -1.68285764e-01, -1.47022809e-01,
-1.67582119e-01, 2.29620903e-01],
[-2.02468037e-01, 1.87307326e-01, -2.79514388e-01,
1.16071557e-01, -1.23245635e-01, -9.67814977e-02,
-1.04795337e-01, 1.40120730e-01],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00]]])
最后
sess.close()