问题背景描述
mnist本身是tensorflow下最常用也是最简单基础的数据包。
所以,在新安装tensorflow,给tensorflow配gpu版本,或者试验tensorflow的其他没有接触过的操作时经常被拿来作为测试之用。
然而,官方文档里所说的引用mnist数据库的方法:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(‘MNIST_data’, one_hot=True)
在直接运行时会报错:
File “C:\Anaconda3\envs\tensorflow-gpu\lib\site-packages\zmq\utils\jsonapi.py”, line 43, in dumps
s = s.encode(‘utf8’)
UnicodeEncodeError: ‘utf-8’ codec can’t encode character ‘\udcd5’ in position 2416: surrogates not allowed
网络上这个问题相关的资料还是很多的,也有解决方案,但是我并不喜欢:
- 很多回答都是基于Linux,对于像我这样只有win10的,有些方法值得商榷
- 对于刚接触tensorflow的人,或者说python不熟的,需要的是一个简单粗暴的使用mnist数据的方式,在encode问题或者用二进制文件读取mnist上花费学习的经历简直是消磨革命热情
所以,不如直接基于mnist的二进制读取,写个和官方的引用方式差不多的猴版module,岂不美哉?
准备
http://yann.lecun.com/exdb/mnist/去下载那4个mnist数据包,解压后放在当前工作目录的mnist目录下(没有就创建一个)
代码
在工作目录下创建mnist.py,高仿从命名开始,内容如下:
# -*- coding: utf-8 -*-
import numpy as np
import struct
# 读取图片,返回 [样本数,图像宽*图像高]的numpy数组
def read_img(path, filename):
with open(path+filename,'rb') as bitfile:
buffer = bitfile.read()
head = struct.unpack_from('>IIII',buffer,0)
print('load head:', head)
imgNum = head[1]
width = head[2]
hight = head[3]
bits = imgNum*width*hight
bitsString = '>'+str(bits)+'B'
offset = struct.calcsize('>IIII')
imgs = struct.unpack_from(bitsString,buffer,offset)
imgs = np.reshape(imgs,[imgNum,width*hight])
print('load image finished')
return imgs
# 读取真值,返回 [样本数]的numpy数组
def read_label(path, filename):
with open(path+filename,'rb') as bitfile:
buffer = bitfile.read()
head = struct.unpack_from('>II',buffer,0)
print('load head:', head)
labelNum = head[1]
labelString = '>'+str(labelNum)+'B'
offset = struct.calcsize('>II')
imgs = struct.unpack_from(labelString,buffer,offset)
label = np.reshape(imgs,[labelNum,1])
labels = np.zeros([labelNum,10])
for _ in range(labelNum):
labels[_,label[_]] = 1.0
print('load labels finished')
return labels
class train(object):
def __init__(self,path='mnist\\'):
self.images = read_img(path, 'x_train.idx3-ubyte')
self.labels = read_label(path,'y_train.idx1-ubyte')
self.it_img=iter(self.images)
self.it_label=iter(self.labels)
def next_batch(self,batch_size):
try:
while True:
batch_img=[]
batch_label=[]
for _ in range(batch_size):
batch_img.append(next(self.it_img))
batch_label.append(next(self.it_label))
return np.array(batch_img), np.array(batch_label)
except StopIteration:
return StopIteration
class test(object):
def __init__(self,path='mnist\\'):
self.images = read_img(path, 'x_test.idx3-ubyte')
self.labels = read_label(path,'y_test.idx1-ubyte')
self.it_img=iter(self.images)
self.it_label=iter(self.labels)
def next_batch(self,batch_size):
try:
while True:
batch_img=[]
batch_label=[]
for _ in range(batch_size):
batch_img.append(next(self.it_img))
batch_label.append(next(self.it_label))
return np.array(batch_img), np.array(batch_label)
except StopIteration:
return StopIteration
两个类,train 和test,读取不同的数据则创建对应的对象即可
食用方法
In[1]: import mnist
In[2]: train=mnist.train()
load head: (2051, 60000, 28, 28)
load image finished
load head: (2049, 60000)
load labels finished
In[3]: test=mnist.test()
load head: (2051, 10000, 28, 28)
load image finished
load head: (2049, 10000)
load labels finished
In[4]: img,label=train.next_batch(batch_size=5)
img
Out[5]:
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]])
label
Out[6]:
array([[ 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
In[7]: import numpy as np
In[8]: np.shape(img)
Out[8]: (5, 784)
In[9]: np.shape(label)
Out[9]: (5, 10)
基本上可以算开袋即食了
现在只包含了next_batch这一方法,如果之后有需求可以再加。
本文提供了一种在Windows环境下直接加载MNIST数据集的方法,通过自定义Python模块来绕过常见的Unicode编码错误,实现对数据集的高效读取。
2049

被折叠的 条评论
为什么被折叠?



