TFRecords
1.TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,
它能更好的利用内存,更方便复制和移动
2.为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
·文件格式:*.tfrecords
·写入文件内容:Example 协议块 (类字典格式)
TFRecords存储
1、建立TFRecord存储器
tf.python_io.TFRecordWriter(path)
写入tfrecords文件
path: TFRecords文件的路径
return:写文件
method
write(record):向文件中写入一个字符串记录
close():关闭文件写入器
2.构造每个样本的Example协议块
tf.train.Example(features=None)
写入tfrecords文件
features:tf.train.Features类型的特征实例
return:example格式协议块
tf.train.Features(feature=None)
构建每个样本的信息键值对
feature:字典数据,key为要保存的名字,
value为tf.train.Feature实例
return:Features类型
tf.train.Feature(**options)
**options:例如
bytes_list=tf.train. BytesList(value=[Bytes])
int64_list=tf.train. Int64List(value=[Value])
tf.train. Int64List(value=[Value])
tf.train. BytesList(value=[Bytes])
tf.train. FloatList(value=[value])
def write_to_tfrecords(self,iamge_batch,label_batch):
"""
将图片的特征值和目标值存进tfrecords
:param iamge_batch: 10张特征值
:param label_batch: 10图片目标值
:return: None
"""
#1.建立一个tfrecords文件存储器
writer=tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
#2.循环的将所有样本写入文件,每张图片样本都要构造example协议
for i in range(10):
#取出第i个图片的特征值和目标值
image=image_batch[i].eval().tostring()
label=label_batch[i].eval()[0]
#构造一个样本的example值
example=tf.train.Example(features=tf.train.Features(feature={
"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
#写入单独的样本
writer.write(example.SerializeToString())
#关闭
writer.close()
return None
在会话中开启
# 存进tfrecords文件
print("开始存储")
cf.write_to_tfrecords(image_batch,label_batch)
print("结束存储")
TFRecords读取
同文件阅读器流程,中间需要解析过程
解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return:一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是float32,int64,string
def read_from_tfrecords(self):
#构造文件队列
file_queue=tf.train.string_input_producer([FLAGS.cifar_tfrecords])
#构造文件阅读器,读取内容 key,value= 一个样本的序列化example
reader=tf.TFRecordReader()
key,value=reader.read(file_queue)
#解析example
features=tf.parse_single_example(value,features={
"image":tf.FixedLenFeature([],tf.string),
"label":tf.FixedLenFeature([],tf.int64)
})
# print(features['image'],features['label'])
#解码内容 读取的格式是string需要解码,如果是int64 float32 不需要解码
image=tf.decode_raw(features['image'],tf.uint8)
#固定图片形状方便与批处理
image_reshape=tf.reshape(image,[self.height,self.weight,self.channel])
label=features['label',tf.int32]
#进行批处理
image_batch,label_batch=tf.train.batch([image_reshape,label],batch_size=10,num_threads=1,capacity=10)
return image_batch,label_batch