# 代码 (Code)
import numpy as np
import struct
from scipy.misc import imsave
from PIL import Image
'''
mnist数据格式 (MNIST data format): reader for IDX3-ubyte image files
'''
def loadImageSet(filename):
    """Load an MNIST IDX3-ubyte image file into a uint8 NumPy array.

    The file begins with a header of four big-endian unsigned 32-bit ints
    (magic number, image count, rows, columns). The header is followed by
    ``count * rows * cols`` unsigned-byte pixel values.

    Args:
        filename: Path to the IDX3 image file, e.g.
            ``./mnist/train-images.idx3-ubyte``.

    Returns:
        A tuple ``(imgs, head)``:
        ``imgs`` is a writable ``np.uint8`` array of shape
        ``(count, 1, rows, cols)``.
        ``head`` is the raw 4-tuple header ``(magic, count, rows, cols)``.

    Raises:
        struct.error: If the file is too short for the 16-byte header or for
            the pixel data the header declares.
    """
    # Read everything at once. The with-block closes the file even on error.
    with open(filename, 'rb') as binfile:
        buffers = binfile.read()

    # Header: 4 big-endian unsigned 32-bit ints (raises struct.error if short).
    head = struct.unpack_from('>IIII', buffers, 0)
    _, img_num, rows, cols = head
    offset = struct.calcsize('>IIII')  # pixel data starts right after header
    bits = img_num * rows * cols       # e.g. 60000 * 28 * 28 for training set

    available = len(buffers) - offset
    if available < bits:
        raise struct.error(
            'file holds %d pixel bytes but header declares %d'
            % (available, bits))

    # np.frombuffer reads the bytes directly instead of materialising a
    # tuple of tens of millions of Python ints. .copy() makes the result
    # writable and independent of the raw bytes object.
    imgs = np.frombuffer(buffers, dtype=np.uint8, count=bits, offset=offset)
    imgs = imgs.reshape(img_num, 1, rows, cols).copy()
    return imgs, head
if __name__ == "__main__":
    file1 = './mnist/train-images.idx3-ubyte'
    imgs, data_head = loadImageSet(file1)
    # Smoke test: export only the first image (index 0) instead of looping
    # over the whole set. imgs[0, 0] is the full rows x cols image, so this
    # does not depend on a hard-coded 28x28 size.
    img = Image.fromarray(imgs[0, 0])
    img.save("mnist.jpg")
# Output: mnist.jpg