【keras】解决 example 案例中 MNIST 数据集下载不了的问题

前言:

 

keras 源码中下载MNIST的方式是 path = get_file(path, origin='https://s3.amazonaws.com/img-datasets/mnist.npz'),数据源是通过 url = https://s3.amazonaws.com/img-datasets/mnist.npz 进行下载的。访问该 url 地址被墙了,导致 MNIST 相关的案例都卡在数据下载的环节。本文主要提供解决方案,让需要的读者可以跑案例的代码感受一下。

 

本文的贡献主要包括如下:

 

1)提供 mnist_npz 数据集;

2)分析了关于 mnist 几个相关的源代码;

3)提供了一种能够顺利运行 keras 源码中 example 下 mnist 的相关案例;

4)找到了另外几种解决方案,提供了相关的链接。

 

numpy.load(path)

 

numpy.load() 函数起到很重要的作用。它可以读取 .npy .npz 等文件类型,并返回对应的数据类型。

1)如果文件类型是 .pny 则返回一个1维数组。

2)如果文件类型是 .npz 则返回一个类似字典的数据类型,包含 {filename: array} 键值对。如,本例中的键值对如下所示:

 

f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
f.close()

 

详情请参考:https://docs.scipy.org/doc/numpy/reference/generated/numpy.load.html

 

 

原始 .\keras\examples\mnist_mlp.py

https://github.com/keras-team/keras/blob/master/examples/mnist_mlp.py

# -*- coding: utf-8 -*-  
'''Trains a simple deep NN on the MNIST dataset.

Gets to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop


batch_size = 128
num_classes = 10
epochs = 20

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x
### 使用 TensorFlow 下载数据集 对于希望利用 TensorFlow 获取特定数据集的情况,一种常见的方式是使用 `tf.keras.utils.get_file` 函数来实现从网络下载指定的数据源并解压至本地路径。此方法适用于任何可以从互联网上公开获取的数据压缩包。 ```python import tensorflow as tf import pathlib data_root_orig = tf.keras.utils.get_file( origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', fname='flower_photos', untar=True ) data_root = pathlib.Path(data_root_orig) print(data_root)[^1] ``` 上述代码片段展示了如何下载花朵照片集合,并将其存储在一个由系统临时分配的位置中。这里使用的参数包括: - `origin`: 数据集的 URL 地址。 - `fname`: 文件名,在下载完成后会以此名称保存于缓存目录下。 - `untar`: 如果设置为 True,则自动解开 tar 归档文件。 除了手动指定链接外,TensorFlow 还内置了一些常用的数据集可以直接调用而无需额外配置。例如 MNIST 或者 Fashion-MNIST 等图像识别领域内的标准测试案例能够很容易地被加载入程序之中[^2]。 另外,为了简化自定义数据集的操作流程,官方还提供了专门用于构建个人化数据管道的支持工具——即 `tensorflow-datasets` 库。它允许开发者更便捷高效地准备训练样本以及验证集等资源[^3]。 最后值得注意的是,某些情况下并不一定非要先将整个数据集下载下来再做进一步操作;部分小型或结构简单的数据集支持在线读取模式,这意味着可以在不占用过多硬盘空间的前提下完成模型的学习过程[^4]。
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值