一.加载数据集部分
使用的数据是路透社新闻数据集,该数据集包含许多短新闻及其对应的主题,由路透社在 1986 年发布。它是一个简单的、广泛使用的文本分类数据集。它包括 46 个不同的主题,且训练集中每个主题都有至少 10 个样本。下载地址,当然Keras内置也有这个数据集,可以直接调包,但是可能会出现下载不了的情况,错误会是:
Exception: URL fetch failure on https://s3.amazonaws.com/text-datasets/reuters.npz: None -- [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:852)
下载好的数据集可以直接放入你安装的Keras的datasets中,我的参考地址是:C:\Users\Lenovo\Anaconda3\Lib\site-packages\keras\datasets
二.代码解释部分
1. 导包:
from keras.datasets import reuters
from keras.utils.np_utils import to_categorical
from keras import models
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
2. 加载数据获取训练数据以及标签label
(train_data,train_labels),(test_data,test_labels)=reuters.load_data(nb_words=10000)
-
其中nb_words=10000指的是在训练集中我们只保留词频最高的前10000个单词。10000名之后的词汇都会被直接忽略,不出现在train_data和test_data中,这样得到的向量不会太大,便于处理。
-
其中 train_data 和 test_data 这两个变量都是由短新闻组成的列表,每条新闻列表是由单词索引组成的,每个索引表示的是词频前10000 中的第几个单词的位置。其中len(train_data)=8982,len(test_data)=2246。
比如以下为训练集中的第一条训练数据的展示:[1, 4, 1378, 2025, 9, 697, 4622, 111, 8, 25, 109, 29, 3650, 11, 150, 244, 364, 33, 30, 30, 1398, 333, 6, 2, 159, 9, 1084, 363, 13, 2, 71, 9, 2, 71, 117, 4, 225, 78, 206, 10, 9, 1214, 8, 4, 270, 5, 2, 7, 748, 48, 9, 2, 7, 207, 1451, 966, 1864, 793, 97, 133, 336, 7, 4, 493, 98, 273, 104, 284, 25, 39, 338, 22, 905, 220, 3465, 644, 59, 20, 6, 119, 61, 11, 15, 58, 579, 26, 10, 67, 7, 4, 738, 98, 43, 88, 333, 722, 12, 20, 6, 19, 746, 35, 15, 10, 9, 1214, 855, 129, 783, 21, 4, 2280, 244, 364, 51, 16, 299, 452, 16, 515, 4, 99, 29, 5, 4, 364, 281, 48, 10, 9, 1214, 23, 644, 47, 20, 324, 27, 56, 2, 2, 5, 192, 510, 17, 12]
如果想要知道这条新闻具体是什么内容, 我们使用下面的函数,可以把对应数字代表的文字给转写出来:
word_index = reuters.get_word_index() #word_index是一个将单词映射为整数索引的字典 # 键值颠倒,将整数索引映射为单词 reverse_word_index = dict([(value, key) for (key, value) in word_index.items()]) decoded_review = ' '.join([reverse_word_index