【TFRecord】-tensorflow

TensorFlow:MNIST转TFRecord及读取
该博客围绕TensorFlow的TFRecord展开,主要介绍了将MNIST数据集转化为TFRecord格式以及读取TFRecord格式的相关内容,还提及了数据处理方面的参考资料。

【TFRecord】-tensorflow

0.难点说明

mnist 数据集中
train 55000
test  5000
valid 10000

1.将MNIST数据集转化为TFRecord格式

import numpy as np
import tensorflow as tf
import os 
from tensorflow.examples.tutorials.mnist import input_data

def _int64_feature(value):
    '''生成整数的属性'''
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    '''生成字符串型的属性'''
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(mnist,name):
    '''
    将数据填入到tf.train.Example的协议缓存区(protocol buffer)
    将缓存区序列化为一个字符串,通过tf.python_io.TFRecordWriter 写入 TFRcords文件
    '''
	images = mnist.train.images     # shape [55000,784]
	labels = mnist.train.labels
	
	pixels = images.shape[1]  # 训练图像的分辨率,作为Example的属性
	num_examples = mnist.train.num_examples

	filename = os.path.join(name+'.tfrecords')
    #  创建一个writer来写TFRecord文件
	writer = tf.python_io.TFRecordWriter(filename)
	for i in range(num_examples):
		image_raw = images[i].tostring() # 将图像转为字符串
		example = tf.train.Example(features=tf.train.Features(
        	feature={
            'pixels': _int64_feature(pixels),
            'label': _int64_feature(np.argmax(labels[i])),
            'image_raw': _bytes_feature(image_raw)}))
		writer.write(example.SerializeToString())  # 序列化为字符串
	writer.close()

# 主程序入口
def main(argv=None):
    """
    主程序入口
    声明处理MNIST数据集的类,这个类在初始化时会自动下载数据
    """
    mnist = input_data.read_data_sets('MNIST_data/', dtype=tf.uint8, one_hot=True)
    if mnist != None:
        print("------------数据加载完毕----------------")
    
    convert_to(mnist,'train')

if __name__	 == '__main__':
    tf.app.run ()

2.读取TFRecord 中的格式

import tensorflow as tf

reader = tf.TFRecordReader() # 创建一个reader来读取TFRecord文件中的Example
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['train.tfrecords'])
_, serialized_example = reader.read(filename_queue) # 从文件中读出一个Example

features = tf.parse_single_example(
    	serialized_example,
    	features={
        	'image_raw': tf.FixedLenFeature([], tf.string),
        	'pixels': tf.FixedLenFeature([], tf.int64),
        	'label': tf.FixedLenFeature([], tf.int64)
    	})

images = tf.decode_raw(features['image_raw'], tf.uint8) # 将字符串解析成图像对应的像素数组
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

	

sess = tf.Session()
coord = tf.train.Coordinator() # 启动多线程处理输入数据
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(10):
	image, label, pixel = sess.run([images, labels, pixels])
	print(label)

参考

  1. 【TensorFlow】数据处理(将MNIST转为TFRecord)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值