1 分析
- 构造文件队列
- 读取二进制数据并进行解码
- 处理图片数据形状以及数据类型,批处理返回
- 开启会话线程运行
2 代码
- 定义CIFAR类,设置图片相关的属性
class CifarRead(object):
"""
二进制文件的读取,tfrecords存储读取
"""
def __init__(self):
# 定义一些图片的属性
self.height = 32
self.width = 32
self.channel = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.label_bytes + self.image_bytes
-
实现读取数据方法bytes_read(self, file_list)
- 构造文件队列
# 1、构造文件队列 file_queue = tf.train.string_input_producer(file_list)
- tf.FixedLengthRecordReader(bytes)读取
# 2、使用tf.FixedLengthRecordReader(bytes)读取 # 默认必须指定读取一个样本 reader = tf.FixedLengthRecordReader(self.all_bytes) _, value = reader.read(file_queue)
- 进行解码操作
# 3、解码操作 # (?, ) (3073, ) = label(1, ) + feature(3072, ) label_image = tf.decode_raw(value, tf.uint8) # 为了训练方便,一般会把特征值和目标值分开处理 print(label_image)
- 将数据的标签和图片进行分割
# 使用tf.slice进行切片 label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32) image = tf.slice(label_image, [self.label_bytes], [self.image_bytes]) print(label, image)
- 处理数据的形状,并且进行批处理
# 处理类型和图片数据的形状 # 图片形状 # reshape (3072, )----[channel, height, width] # transpose [channel, height, width] --->[height, width, channel] depth_major = tf.reshape(image, [self.channel, self.height, self.width]) print(depth_major) image_reshape = tf.transpose(depth_major, [1, 2, 0]) print(image_reshape) # 4、批处理 image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)