Analysis of the Keras example file babi_rnn.py

The purpose of this code is similar to that of https://blog.youkuaiyun.com/zhqh100/article/details/105193991

The dataset is the same one; the difference is that this example takes its text from qa2_two-supporting-facts_train.txt, so there is somewhat more text.

The first example:

1 Mary moved to the bathroom.
2 Sandra journeyed to the bedroom.
3 Mary got the football there.
4 John went to the kitchen.
5 Mary went back to the kitchen.
6 Mary went back to the garden.
7 Where is the football? 	garden	3 6
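
Each line in the raw file starts with a sentence id; question lines additionally carry the answer and the supporting-fact ids separated by tabs. Below is a simplified parsing sketch of how such lines can be turned into (story, question, answer) triples; it skips the only_supporting/max_length handling that the real example supports, and the file path in the usage comment is only an assumption about where the extracted bAbI archive lives.

import re

def tokenize(sent):
    # Split a sentence into tokens, keeping punctuation as its own token,
    # e.g. 'Where is the football?' -> ['Where', 'is', 'the', 'football', '?']
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]

def parse_stories(lines):
    # Collect (story, question, answer) triples from bAbI-formatted lines.
    # A sentence id of 1 starts a new story; a line containing a TAB holds
    # 'question \t answer \t supporting-fact ids'.
    data, story = [], []
    for line in lines:
        nid, line = line.strip().split(' ', 1)
        if int(nid) == 1:
            story = []
        if '\t' in line:
            question, answer, _ = line.split('\t')
            data.append((sum(story, []), tokenize(question), answer))
        else:
            story.append(tokenize(line))
    return data

# Usage (the extracted archive path is an assumption):
# with open('tasks_1-20_v1-2/en/qa2_two-supporting-facts_train.txt') as f:
#     train = parse_stories(f.readlines())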

The word-to-index mapping is:

{'.': 1, '?': 2, 'Daniel': 3, 'John': 4, 'Mary': 5, 'Sandra': 6, 'Where': 7, 'apple': 8, 'back': 9, 'bathroom': 10, 'bedroom': 11, 'discarded': 12, 'down': 13, 'dropped': 14, 'football': 15, 'garden': 16, 'got': 17, 'grabbed': 18, 'hallway': 19, 'is': 20, 'journeyed': 21, 'kitchen': 22, 'left': 23, 'milk': 24, 'moved': 25, 'office': 26, 'picked': 27, 'put': 28, 'the': 29, 'there': 30, 'to': 31, 'took': 32, 'travelled': 33, 'up': 34, 'went': 35}
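
This mapping comes from sorting the whole vocabulary (stories, questions and answers of both the training and test sets) and numbering it from 1, with 0 reserved for padding; because the sort is plain ASCII order, punctuation and capitalized names come before the lowercase words. A rough sketch, where train and test are the triples produced by the parsing step above:

# Build the word index over all training and test triples;
# index 0 is implicitly reserved for the padding value.
vocab = set()
for story, question, answer in train + test:
    vocab |= set(story + question + [answer])
vocab = sorted(vocab)

vocab_size = len(vocab) + 1                       # 35 words + padding -> 36
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))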

After encoding, the material above becomes three arrays: the padded story, the encoded question, and the one-hot encoded answer:

[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  5 25 31 29 10  1  6 21 31 29 11  1  5 17
 29 15 30  1  4 35 31 29 22  1  5 35  9 31 29 22  1  5 35  9 31 29 16  1]
[ 7 20 29 15  2]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
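
The long run of leading zeros is pad_sequences left-padding the story up to story_maxlen (552 here); the question is padded to query_maxlen (5), and the answer becomes a one-hot vector of length vocab_size (36), with index 16 ('garden') set to 1. A sketch of this vectorization step, along the lines of the example's vectorize_stories:

import numpy as np
from keras.preprocessing.sequence import pad_sequences

def vectorize_stories(data, word_idx, story_maxlen, query_maxlen, vocab_size):
    # Map every token to its index, pad stories/questions with zeros on the
    # left (pad_sequences' default), and one-hot encode the one-word answer.
    xs, xqs, ys = [], [], []
    for story, query, answer in data:
        xs.append([word_idx[w] for w in story])
        xqs.append([word_idx[w] for w in query])
        y = np.zeros(vocab_size)
        y[word_idx[answer]] = 1
        ys.append(y)
    return (pad_sequences(xs, maxlen=story_maxlen),
            pad_sequences(xqs, maxlen=query_maxlen),
            np.array(ys))

# x, xq, y = vectorize_stories(train, word_idx, story_maxlen, query_maxlen, vocab_size)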

Here the answer is one-hot encoded, so the loss is categorical_crossentropy; babi_memnn.py uses sparse_categorical_crossentropy instead, so it does not need to one-hot encode the answer.
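
The difference only concerns how the labels are fed in; a minimal comparison of the two compile calls (a sketch, the optimizer names are incidental, the point is that the loss must match the label format):

# babi_rnn.py: y is a one-hot vector of length vocab_size per sample
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# babi_memnn.py: the answer stays an integer word index, so the sparse loss
# computes the same crossentropy without building one-hot vectors:
# model.compile(optimizer='rmsprop',
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])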

Training data shapes:

x.shape = (1000, 552)
xq.shape = (1000, 5)
y.shape = (1000, 36)

Network architecture:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 552)          0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 5)            0
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 552, 50)      1800        input_1[0][0]
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 5, 50)        1800        input_2[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, 100)          60400       embedding_1[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   (None, 100)          60400       embedding_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 200)          0           lstm_1[0][0]
                                                                 lstm_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 36)           7236        concatenate_1[0][0]
==================================================================================================
Total params: 131,636
Trainable params: 131,636
Non-trainable params: 0
__________________________________________________________________________________________________
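
The summary corresponds to a two-branch functional model: the story and the question are embedded into 50-dimensional vectors, each branch is encoded by its own LSTM with 100 units, the two 100-dimensional encodings are concatenated, and a softmax Dense layer predicts the answer word. A sketch that reproduces this summary (story_maxlen = 552, query_maxlen = 5, vocab_size = 36 for this dataset):

from keras import layers, Model

story_maxlen, query_maxlen, vocab_size = 552, 5, 36   # values for this dataset
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100

# Story branch: (None, 552) -> Embedding(36, 50) -> LSTM(100)
sentence = layers.Input(shape=(story_maxlen,), dtype='int32')
encoded_sentence = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(sentence)
encoded_sentence = layers.LSTM(SENT_HIDDEN_SIZE)(encoded_sentence)

# Question branch: (None, 5) -> Embedding(36, 50) -> LSTM(100)
question = layers.Input(shape=(query_maxlen,), dtype='int32')
encoded_question = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(question)
encoded_question = layers.LSTM(QUERY_HIDDEN_SIZE)(encoded_question)

# Concatenate the two 100-dimensional encodings and predict the answer word
merged = layers.concatenate([encoded_sentence, encoded_question])
preds = layers.Dense(vocab_size, activation='softmax')(merged)

model = Model([sentence, question], preds)
model.summary()   # reproduces the table above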

——————————————————————

Table of contents

Analysis of Keras example files
