MNIST数据读取分析

本文详细解析了MNIST数据集的读取过程,包括从输入数据到内存的转换,数据预处理,one-hot编码,以及利用next_batch进行数据切分。通过分析read_data_sets方法和DataSet类,探讨了数据加载、验证集划分、数据归一化和随机化等步骤,强调了数据完全加载到内存的特点。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

从input_data.py中获取


Input_data.py只是个过渡,真正是mnist.pyread_data_sets方法


具体分析read_data_sets方法


这里调用了DataSet,imagets/labels都是空,看看


这里重点看fake_data,设置了2个变量


又设了4,应是为了后面的方法准备


有用的方法,就是next_batch.


切回read_data_sets


这里调用了base.maybe_download


from tensorflow.python.platform import gfile

Gfile明显是处理文件下载的.


获取目录


直接将远程数据下载到本地,又调用urlretrieve_with_retry(url, filename=None)方法



这里执行了URL的文件下载,先不管


把临时文件转成正式文件,下载完成.


根据收到的文件名,打开文件.从中提取出train_images,又调用了解开image数据的方法


With...as... 这个语法是用来代替传统的try...finally语法的,防止打开时出错无处理.

magic=_read32(bytestream): 从byte流中获取.动用了


frombuffer: 读取图片数据的方法



如果第一个4位数,不是2051,抛出异常.


接下来,num_images,rows,cols的具体值.这里的bytestream应是指针顺序读取.

buf = bytestream.read(rows * cols * num_images)

data = numpy.frombuffer(buf, dtype=numpy.uint8)

读出真正的数据(先定义buf,再读取,npbuffer处理方式)

data = data.reshape(num_images, rows, cols, 1)

转置成arrays 4

这里,是把文件的内容全都读进内存,没有分batch.


获取LABEL



前面一样,2049判断文件的正确性

num_items: 接下来的值,为数据量

bytestream.read: 读取到buf

numpy.frombuffer: 存到labels

one_hot处理


num_labels = labels_dense.shape[0]

labels_dense1维矩阵,这里获取了矩阵的长度,相当于标签的数量了.

index_offset = numpy.arange(num_labels) * num_classes

index_offset: [0,10,20,30...]

labels_one_hot = numpy.zeros((num_labels, num_classes))

开出全0矩阵

labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1

labels_one_hot1维长度(flat)进行赋1,这样变成相应位置为1,其余为0

abels_dense.ravel():多维转1,用于取得标签的值

 

: one hot 的数据,是用偏移来算的,故而原LABEL是要0,1,2...这样的数据

经过以上处理,LABEL的所有数据,也都到内存中了.


数据切分


判断验证集数据量(validation_size,默认5000)是否小于样本总量train_images


把样本数据的前validation_size,当成验证数据集,之后的当学习集.


生成各自的数据集,又调用DataSet


assert: 下断言,不满足时跳出.

self._num_examples = images.shape[0]

获取样本数量


只处理黑白图片(depth=1)

3维转2(样本量,样本长*),相当于把图片的2维数据变成1维长条型.


把image的数据,int转变成float32

100* 1/255=100/255,类似归一化.


next_batch

获取批次数据


start = self._index_in_epoch

开始序号

self._index_in_epoch += batch_size

增加一个batch

if self._index_in_epoch > self._num_examples:

      # Finished epoch

      self._epochs_completed += 1

完成一轮

# Shuffle the data

perm = numpy.arange(self._num_examples)

numpy.random.shuffle(perm)

perm随机排序

self._images = self._images[perm]

self._labels = self._labels[perm]

获取乱序之后的数据

 

# Start next epoch

start = 0

self._index_in_epoch = batch_size

assert batch_size <= self._num_examples

设置开始与结束的位置

end = self._index_in_epoch

正式设置结束位置

return self._images[start:end], self._labels[start:end]

返回

 

总体思路:

1.如果刚开始(或新的一轮),就把顺序打乱,从头开始,batch获取量.

2.每次都获取一个batch的数据,直到结束,增加一个迭代.

 

这里的前提,是数据全装到内存了.

故而可知,数据的处理,完全是由编程人员自控的.









评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值