[cifar-10]二进制图片文件的注释理解
前面讲到到的[catsVSdogs]猫狗大战代码注释讲解http://blog.youkuaiyun.com/wsljqian/article/details/78091425,将利用TensorFlow的框架搭建了一个比较完备的结构。从图像样本文件的读取、做标签、batch分批处理;再到CNN模型的搭建、定义损失函数、训练函数、准确率判断函数;最后,调取两部分文件,开始实现样本的训练功能,最后输入单张测试图片,检查它的训练成果率,整个神经网络也主要就是这些部分组成。
今天说到的cifar-10数据库:连接地址:https://www.cs.toronto.edu/~kriz/cifar.html
60000张 32X32 彩色图像 10类
50000张训练
10000张测试
现在看一下下载后的CIFAR-100 binary version的二进制数据集的样子
其中data_batch_1\2\3\4\5为训练集的数据,每个里面包含10000个二进制数据样本,每个样本是<1 x label><3072 x pixel>,意思是1个label+32*32*3pixel=1+3072
一个text_batch为测试集的数据。
########################################################################################################
代码部分,进入主题,具体细节参照http://blog.youkuaiyun.com/wsljqian/article/details/78091425,不在赘述
########################################################################################################
一:cifar10_input.py获取样本
import tensorflow as tf
import numpy as np
import os
#%%
def read_cifar10(data_dir,is_train,batch_size,shuffle):
img_width=32
img_height=32
img_depth=3
label_bytes=1 #label_bytes占一位
img_bytes=img_width*img_height*img_depth
with tf.name_scope('input'):
if is_train:
filenames=[os.path.join(data_dir,'data_batch_%d.bin'%ii)
for ii in np.arange(1,6)]#读训练的五个文件
else:
filenames=[os.path.join(data_dir,'test_batch.bin')]
filename_queue=tf.train.string_input_producer(filenames)
reader=tf.FixedLengthRecordReader(label_bytes+img_bytes)#读取队列 1+3072
key,value=reader.read(filename_queue)
record_bytes=tf.decode_raw(value,tf.uint8)
label=tf.slice(record_bytes,[0],[label_bytes])#切出来第一个元素
label=tf.cast(label,tf.int32)
image_raw=tf.slice(record_bytes,[label_bytes],[img_bytes])
image_raw=tf.reshape(image_raw,[img_depth,img_height,img_width])
image=tf.transpose(image_raw,(1,2,0))#convert from D/H/W to H/W/D
image=tf.cast(image,tf.float32)
#image=tf.image.per_image_standardization(image)#标准化0到1
if shuffle:#洗牌,重新调整顺序
images,label_batch=tf.train.shuffle_batch(
[image,label],
batch_size=batch_size,
num_threads=16,#线程=16个
capacity=2000,
min_after_dequeue=1500)#剩下最少个数
else:
images,label_batch=tf.train.batch(
[image,label],
batch_size=batch_size,
num_threads=16,
capacity=2000)
return images,tf.reshape(label_batch,[batch_size])
注释:if语句部分用来选择是训练数据,还是测试数据正确性后面这一部分呢,就是看看前面的读取数据样本有没有成功,不过在训练时候,这部分可以注释掉,这里也拿出来看看,和catsVSdogs的部分差不多
%%TEST
import matplotlib.pyplot as plt
BATCH_SIZE = 2
data_dir = 'D:/Anaconda3/projects/CAFAR10/data/cifar-10-batches-bin/'
image_batch,label_batch = read_cifar10(data_dir,
is_train=True,
batch_size=BATCH_SIZE,
shuffle=True)
with tf.Session() as sess:
i = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i<1:
img,label = sess.run([image_batch,label_batch])
#just test one batch
for j in np.arange(BATCH_SIZE):
print('label: %d' %label[j])
plt.imshow(img[j,:,:,:])#4D数据后面全用冒号
plt.show()
i+=1
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
这里面的BATCH_SIZE = 2如果输出结果显示的图片就是两种,说明数据读取成功。【不多注释了,具体的详解在前面的一个里面讲的比较详细】