写在前面
最近想体验一下AI方面的技术,但是在做一些TensorFlow官网上的简单实验时却发现由于网络、环境等问题,里面的示例代码并不能很顺利地跑在我的机器上,主要原因出在从远端自动load_data
的时候,总是会报https、ssl等网络相关的问题。之前的Fashion MNIST,通过网上的一些攻略,找到了本地导入的方法,于是在进行下一个imdb(电影评论文本分类)的小实验时,想自己写一个类似功能的函数来本地导入数据集,于是就有了本文,代码和思路供参考。(另外,我发现keras中的imdb这部分代码并不长,于是把另外一个也需要联网的imdb.get_word_index()
也顺便一起改成离线版了,反正网络问题要出现就都是一起出现的…)
我的环境
操作系统:CentOS 7
Conda版本:从TUNA下载的Anaconda3-5.3.1-Linux-x86_64.sh
TensorFlow版本:1.15.4
Keras:Tensorflow中自带的
使用方法
-
下载imdb相关的两个文件(imdb.npz, imdb_word_index.json)到本地目录(例如
/home/user/datasets/TF_imdb
)。度盘下载链接 / 提取码:ic54 -
复制下方小节的【离线导入代码块】到.py文件中,修改代码中标出的路径,定位到刚才下载的两个文件。(如果缺少相应的包,可以用pip install)
-
把官网示例代码中的
imdb.load_data(num_words=10000)
修改成我们自定义的manual_imdb_load_data(num_words=10000)
;同理,imdb.get_word_index()
修改成manual_imdb_get_word_index()
-
其他代码可以和官网示例中保持一致,运行测试即可。
【离线导入代码块】:由于是从官方代码中改的,有些注释我保留了。用之前记得先把下载的两个文件放在对应路径下。
import tensorflow as tf
from tensorflow import keras
import numpy as np
print(tf.__version__) # 简单检测一下tensorflow的版本
# 修改imdb的一些函数(因为https有问题)
from tensorflow.python.platform import tf_logging as logging
import json
def manual_imdb_load_data(path='imdb.npz',
num_words=None,
skip_top=0,
maxlen=None,
seed=113,
start_char=1,
oov_char=2,
index_from