TensorFlow本地导入imdb数据集的方法

写在前面

最近想体验一下AI方面的技术,但是在做一些TensorFlow官网上的简单实验时却发现由于网络、环境等问题,里面的示例代码并不能很顺利地跑在我的机器上,主要原因出在从远端自动load_data的时候,总是会报https、ssl等网络相关的问题。之前的Fashion MNIST,通过网上的一些攻略,找到了本地导入的方法,于是在进行下一个imdb(电影评论文本分类)的小实验时,想自己写一个类似功能的函数来本地导入数据集,于是就有了本文,代码和思路供参考。(另外,我发现keras中的imdb这部分代码并不长,于是把另外一个也需要联网的imdb.get_word_index()也顺便一起改成离线版了,反正网络问题要出现就都是一起出现的…)

我的环境

操作系统:CentOS 7
Conda版本:从TUNA下载的Anaconda3-5.3.1-Linux-x86_64.sh
TensorFlow版本:1.15.4
Keras:Tensorflow中自带的

使用方法

  1. 下载imdb相关的两个文件(imdb.npz, imdb_word_index.json)到本地目录(例如/home/user/datasets/TF_imdb)。度盘下载链接 / 提取码:ic54

  2. 复制下方小节的【离线导入代码块】到.py文件中,修改代码中标出的路径,定位到刚才下载的两个文件。(如果缺少相应的包,可以用pip install

  3. 把官网示例代码中的imdb.load_data(num_words=10000)修改成我们自定义的manual_imdb_load_data(num_words=10000);同理,imdb.get_word_index()修改成manual_imdb_get_word_index()

  4. 其他代码可以和官网示例中保持一致,运行测试即可。

离线导入代码块】:由于是从官方代码中改的,有些注释我保留了。用之前记得先把下载的两个文件放在对应路径下。

import tensorflow as tf
from tensorflow import keras
import numpy as np

print(tf.__version__)	# 简单检测一下tensorflow的版本

# 修改imdb的一些函数(因为https有问题)
from tensorflow.python.platform import tf_logging as logging
import json
def manual_imdb_load_data(path='imdb.npz',
              num_words=None,
              skip_top=0,
              maxlen=None,
              seed=113,
              start_char=1,
              oov_char=2,
              index_from
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值