A simple example of bidirectional_dynamic_rnn in TensorFlow and its related parameters

This post assumes basic familiarity with recurrent neural networks.

1 TensorFlow documentation link

TensorFlow manual (r1.13): https://github.com/tensorflow/docs/tree/r1.13/site/en/api_docs/python/tf

2 RNN-style cells

RNNs are a family of neural network models for sequence data. TensorFlow ships ready-made implementations that can be called directly with the appropriate parameters.
The RNN family includes the plain RNN, LSTM, and GRU cells, all of which are invoked in a similar way. When you need to extract information from (variable-length) sequences in both directions, call bidirectional_dynamic_rnn.
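As a quick sketch (num_units=4 is an arbitrary choice here), the common cell types are constructed the same way, and any of them can serve as the forward/backward cells used below:

import tensorflow as tf

# The three common cell types share one construction pattern
rnn_cell  = tf.nn.rnn_cell.BasicRNNCell(num_units=4)   # plain RNN
gru_cell  = tf.nn.rnn_cell.GRUCell(num_units=4)        # GRU
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)  # LSTM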

3 A simple example

3.1 Data

Suppose we have two word sequences, ['汤姆','正在','追赶','杰瑞'] ("Tom is chasing Jerry") and ['杰瑞','躲进','了','草丛'] ("Jerry hid in the grass"), that we want to model:
First, remove stop words: the particle '了' is dropped from the text, giving the preprocessed data ['汤姆','正在','追赶','杰瑞'], ['杰瑞','躲进','草丛'].
Next, apply word embedding to turn the words into numbers the computer can operate on. Each word is mapped to a vector (assume the embedding dimension is 3; a lookup sketch follows the vector list):

汤姆:[ 1.86343176, -1.59499821, -0.46139333]
正在:[-0.01371737, -0.77550861,  0.10185807]
追赶:[ 0.9538792 , -1.08308753, -1.42873875]
杰瑞:[-0.90520407, -0.30969364,  0.41886028]
杰瑞:[-0.90520407, -0.30969364,  0.41886028]
躲进:[-0.76595131,  0.90407954, -1.00486262]
草丛:[-1.36344792,  0.97559982, -0.28362745]
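In practice this mapping is usually done with tf.nn.embedding_lookup over a trainable embedding table; a minimal sketch (the id assignments are hypothetical, and the table values here are random rather than the vectors listed above):

import tensorflow as tf

vocab_size, embed_dim = 7, 3  # 6 words plus a reserved padding id (assumed)
embedding = tf.get_variable('embedding', [vocab_size, embed_dim],
                            dtype=tf.float64)
# id 0 is reserved for padding; 1..6 map to 汤姆/正在/追赶/杰瑞/躲进/草丛
word_ids = tf.constant([[1, 2, 3, 4],
                        [4, 5, 6, 0]])
embedded = tf.nn.embedding_lookup(embedding, word_ids)  # shape (2, 4, 3)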

The two sentences are two sequences, of lengths 4 and 3. The second sequence must be zero-padded so both sequences have the same length; otherwise they cannot be batched as model input.
The resulting input data is:

import numpy as np

input_data = np.array([[[ 1.86343176, -1.59499821, -0.46139333],
        [-0.01371737, -0.77550861,  0.10185807],
        [ 0.9538792 , -1.08308753, -1.42873875],
        [-0.90520407, -0.30969364,  0.41886028]],
       [[-0.90520407, -0.30969364,  0.41886028],
        [-0.76595131,  0.90407954, -1.00486262],
        [-1.36344792,  0.97559982, -0.28362745],
        [ 0.        ,  0.        ,  0.        ]]])
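Rather than writing the padded tensor out by hand, the zero padding can be done programmatically; a minimal numpy sketch (seq1 and seq2 stand for the two lists of word vectors above):

import numpy as np

def pad_batch(seqs, dim):
    """Zero-pad a list of (length, dim) sequences to the longest length."""
    max_len = max(len(s) for s in seqs)
    batch = np.zeros((len(seqs), max_len, dim))
    for i, s in enumerate(seqs):
        batch[i, :len(s), :] = s
    return batch

# input_data = pad_batch([seq1, seq2], dim=3)  # seq1 has 4 vectors, seq2 has 3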

3.2 Building the RNN cells

import tensorflow as tf

seq_length = [4, 3]  # true (unpadded) lengths of the two sequences
# Build the forward and backward cells
cell_fw = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
cell_bw = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
# Initialize the cell states; if omitted, they default to zeros
initial_state_fw = cell_fw.zero_state(2, dtype=tf.float64)
initial_state_bw = cell_bw.zero_state(2, dtype=tf.float64)
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell_fw, cell_bw=cell_bw, inputs=input_data,
    initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
    sequence_length=seq_length, dtype=tf.float64)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
o1, s1 = sess.run([outputs, states])
print(o1)
print(s1)

Running this gives the following o1 and s1:

# o1
(array([[[ 0.16748314, -0.26319363,  0.04769821, -0.05121024],
         [ 0.13402652, -0.29065524,  0.05903449, -0.03146086],
         [ 0.25099752, -0.24584206,  0.01866094, -0.04430275],
         [ 0.06783996, -0.21646949, -0.00682754,  0.01544742]],
 
        [[-0.04379599,  0.03874637, -0.05141118,  0.05217115],
         [-0.13575499,  0.09677944, -0.20853611,  0.07095939],
         [-0.20246804,  0.18730733, -0.27951439,  0.11607156],
         [ 0.        ,  0.        ,  0.        ,  0.        ]]]),
 array([[[ 2.92630630e-01,  9.11851631e-02, -1.00478860e-01,
          -4.01103255e-02],
         [ 6.06710286e-02, -4.59981474e-02, -1.89261316e-01,
          -1.42948033e-04],
         [-1.00241262e-02, -1.24285009e-02, -1.34683944e-01,
           2.99961826e-02],
         [-4.42878719e-02, -8.51187469e-02, -4.84867945e-02,
           9.52617845e-03]],
 
        [[-2.16223927e-01, -1.83200812e-01, -2.72675180e-01,
           1.24081684e-01],
         [-1.68285764e-01, -1.47022809e-01, -1.67582119e-01,
           2.29620903e-01],
         [-1.23245635e-01, -9.67814977e-02, -1.04795337e-01,
           1.40120730e-01],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
           0.00000000e+00]]]))
# s1
(LSTMStateTuple(c=array([[ 0.14960228, -0.36727263, -0.0156995 ,  0.03011006],
        [-0.58717123,  0.39381341, -0.57109581,  0.16564561]]), h=array([[ 0.06783996, -0.21646949, -0.00682754,  0.01544742],
        [-0.20246804,  0.18730733, -0.27951439,  0.11607156]])),
 LSTMStateTuple(c=array([[ 0.5831851 ,  0.28922052, -0.1995207 , -0.14177551],
        [-0.46907702, -0.38997103, -0.49948155,  0.23605249]]), h=array([[ 0.29263063,  0.09118516, -0.10047886, -0.04011033],
        [-0.21622393, -0.18320081, -0.27267518,  0.12408168]])))

o1 is the hidden-layer output at every time step. It is a tuple whose two elements are the forward and backward results; within each element there is one hidden output per time step of each input sequence (sequence 1 has length 4, sequence 2 has length 3, and the padded step stays all-zero).
s1 holds the final-time-step cell state c and hidden output h of the forward and backward passes for each sequence.
Comparing the two: for the forward direction, s1's h equals the o1 output at each sequence's last valid time step; for the backward direction, which reads the input in reverse, it equals the o1 output at the first time step.
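This correspondence can be verified directly on the numpy results; a minimal sketch using the seq_length list from above:

import numpy as np

fw_out, bw_out = o1      # forward / backward outputs, each of shape (2, 4, 4)
fw_state, bw_state = s1  # forward / backward LSTMStateTuple

for i, length in enumerate(seq_length):
    # forward: h equals the output at the last valid time step
    assert np.allclose(fw_out[i, length - 1], fw_state.h[i])
    # backward: the input is read in reverse, so h equals the
    # output at the first time step
    assert np.allclose(bw_out[i, 0], bw_state.h[i])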

initial_state_fw and initial_state_bw can also be initialized from the output of another RNN.
For example, using the s1 above as the initial state gives results different from zero initialization:

initial_state_fw = s1[0]  # forward LSTMStateTuple from the first run
initial_state_bw = s1[1]  # backward LSTMStateTuple from the first run
outputs2, states2 = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell_fw, cell_bw=cell_bw, inputs=input_data,
    initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
    sequence_length=seq_length, dtype=tf.float64)
# run the initializer after the graph has been extended; note that this
# also re-randomizes the existing LSTM weights
sess.run(tf.global_variables_initializer())
o2, s2 = sess.run([outputs2, states2])
print(o2)
print(s2)

The resulting o2 and s2:

# o2
(array([[[ 0.09264089, -0.08258323,  0.05526498, -0.19484421],
         [ 0.12510858, -0.04264845,  0.06277278, -0.16281255],
         [ 0.05611837,  0.04899326,  0.07337347, -0.25136631],
         [ 0.12133026,  0.01433822,  0.02793275, -0.08281042]],
 
        [[-0.09847787,  0.11218909, -0.17476367,  0.02773898],
         [-0.13967064,  0.09041934, -0.27535409,  0.19656187],
         [-0.07928222,  0.00569383, -0.24421959,  0.19889229],
         [ 0.        ,  0.        ,  0.        ,  0.        ]]]),
 array([[[ 0.15853502,  0.33094248, -0.29151092, -0.21539306],
         [ 0.16352575,  0.18191235, -0.23442651, -0.14024268],
         [ 0.18102019,  0.16179395, -0.21143666, -0.13473829],
         [ 0.16336909,  0.11958743, -0.14033709, -0.05007693]],
 
        [[-0.09515931, -0.22884901, -0.06287023,  0.19503592],
         [-0.07946009, -0.12400517, -0.04894213,  0.30135051],
         [-0.16788661, -0.10867359, -0.06732264,  0.22288605],
         [ 0.        ,  0.        ,  0.        ,  0.        ]]]))
# s2
(LSTMStateTuple(c=array([[ 0.29375743,  0.0233761 ,  0.07331977, -0.22109375],
        [-0.20348281,  0.00997103, -0.46476339,  0.75725753]]), h=array([[ 0.12133026,  0.01433822,  0.02793275, -0.08281042],
        [-0.07928222,  0.00569383, -0.24421959,  0.19889229]])),
 LSTMStateTuple(c=array([[ 0.36527035,  0.51972398, -0.49281985, -0.45606273],
        [-0.22130596, -0.48458956, -0.15229997,  0.49193544]]), h=array([[ 0.15853502,  0.33094248, -0.29151092, -0.21539306],
        [-0.09515931, -0.22884901, -0.06287023,  0.19503592]])))

4 Output results

Interpreting o and s: the hidden output at each time step in o can be read as an encoding of all the elements seen up to that step, so the hidden output at the last time step can be read as an encoding of the whole sequence.
The forward and backward hidden outputs can also be concatenated and used as input to downstream tasks.
For example, concatenating the two elements of o1:

h_concat = tf.concat(o1, 2)  # concatenate forward and backward along the last axis
print(sess.run(h_concat))
# The concatenated result: the hidden dimension doubles from 4 to 8
array([[[ 1.67483141e-01, -2.63193627e-01,  4.76982142e-02,
         -5.12102408e-02,  2.92630630e-01,  9.11851631e-02,
         -1.00478860e-01, -4.01103255e-02],
        [ 1.34026516e-01, -2.90655240e-01,  5.90344865e-02,
         -3.14608636e-02,  6.06710286e-02, -4.59981474e-02,
         -1.89261316e-01, -1.42948033e-04],
        [ 2.50997517e-01, -2.45842063e-01,  1.86609356e-02,
         -4.43027535e-02, -1.00241262e-02, -1.24285009e-02,
         -1.34683944e-01,  2.99961826e-02],
        [ 6.78399579e-02, -2.16469487e-01, -6.82754071e-03,
          1.54474225e-02, -4.42878719e-02, -8.51187469e-02,
         -4.84867945e-02,  9.52617845e-03]],
       [[-4.37959874e-02,  3.87463715e-02, -5.14111802e-02,
          5.21711487e-02, -2.16223927e-01, -1.83200812e-01,
         -2.72675180e-01,  1.24081684e-01],
        [-1.35754985e-01,  9.67794414e-02, -2.08536112e-01,
          7.09593910e-02, -1.68285764e-01, -1.47022809e-01,
         -1.67582119e-01,  2.29620903e-01],
        [-2.02468037e-01,  1.87307326e-01, -2.79514388e-01,
          1.16071557e-01, -1.23245635e-01, -9.67814977e-02,
         -1.04795337e-01,  1.40120730e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00]]])

Finally, close the session:

sess.close()
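Equivalently, wrapping the session in a with-block closes it automatically:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o1, s1 = sess.run([outputs, states])
# the session is closed on exiting the block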