mnist数据集介绍
样本在官网下载http://yann.lecun.com/exdb/mnist/
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
#用代码看一下数据集的形状
from tensorflow.examples.tutorials.mnist import input_data
# 载入数据集,如果根目录有就调用,没有就下载
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
print('训练数据集')
print(mnist.train.images.shape)
print(mnist.train.labels.shape)
print('验证数据集')
print(mnist.validation.images.shape)
print(mnist.validation.labels.shape)
print('测试数据集')
print(mnist.test.images.shape)
print(mnist.test.labels.shape)
#结果如下:
训练数据集
(55000, 784)
(55000, 10)
验证数据集
(5000, 784)
(5000, 10)
测试数据集
(10000, 784)
(10000, 10)
也就是说images集是由28*28像素的图片组成,源文件都是字节文件,我们写个脚本看一下就行。
from tensorflow.examples.tutorials.mnist import input_data
import scipy.misc
import matplotlib.pyplot as plt
import os
# 载入数据集,如果根目录有就调用,没有就下载
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
img=mnist.train.images[0,:].reshape(28,28)
def mnist_plot_img(img):
(rows, cols) = img.shape
plt.figure()
plt.gray()
plt.imshow(img)
plt.show()
mnist_plot_img(img)

labels是onehot(独热码,在英文文献中称做 one-hot code, 直观来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制。)编码的标签,0-9一共十个状态所以他就是一个十维,打印一个看一下,由图可知刚刚打印的数字是七。
print(mnist.train.labels[0,:])

以上就是mnist训练集的内容。
input_data模块:
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=tf.float32):
**dtype:**的作用是将图像像素点的灰度值从[0, 255]转变为[0.0, 1.0]
定义好之后看提取数据模块,之前的分析都是使用的写好的方法,现在看一下这些方法是如何实现的。
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data

结合代码和文件结构很容易看出:代码中_read32()的作用是从文件流中动态读取4位数据并转换为uint32的数据。
image文件的前四位为魔术码(magic number),只有检测到这4位数据的值和2051相等时,才代表这是正确的image文件,才会继续往下读取。接下来继续读取之后的4位,代表着image文件中,所包含的图片的数量(num_images)。再接着读4位,为每一幅图片的行数(rows),再后4位,为每一幅图片的列数(cols)。最后再读接下来的rows * cols * num_images位,即为所有图片的像素值。最后再用reshape方法将读取到的所有像素值装换为[index, rows, cols, depth](depth=1表示一维空间,彩色是在处理的时候是三维空间RGB)的矩阵。这样就将全部的image数据读取了出来。reshape很有意思如果你不知道图片总数量的话第一个值可以用-1代替,意思就是有一个有很多图片像素用?784表示,把这个矩阵,变换成2828的不知道多少个的矩阵
同理标签集也是这样读取的
856

被折叠的 条评论
为什么被折叠?



