Tensorflow使用TFRecord构建自己的数据集并读取

本文详细介绍如何使用TensorFlow的TFRecord格式存储和读取图片数据。首先介绍如何定义属性来存储图片和标签,接着演示如何将图片转换为TFRecord文件,最后展示如何从TFRecord文件中读取数据并用画图工具显示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Tensorflow使用TFRecord构建自己的数据集并读取

参考文章:

 http://blog.youkuaiyun.com/freedom098/article/details/56011858 
还有 优酷上kevin大神的视频

目标:1、将自己的数据集以TFRecord格式存储。

          2、从TFRecord中读取数据,并使用画图工具,以图片形式展现。


以一个图片为例:


一、将图片存储TFRecod

# 生成整数型的属性
def _int64_feature(value):
    if not isinstance(value,list):
        value=[value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

#生成字符串型的属性
def _byte_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#将图片存储到tfrecord中
def convert_to_tfrecord(images, labels, save_dir, name):#从图片路径读取图片编码成tfrecord
    '''''convert all images and labels to one tfrecord file. 
    Args: 
        images: list of image directories, string type 
        labels: list of labels, int type 
        save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/' 
        name: the name of tfrecord file, string type, e.g.: 'train' 
    Return: 
        no return 
    Note: 
        converting needs some time, be patient... 
    ''' 
    filename = (save_dir + name + '.tfrecords')  
    n_samples = len(labels) 
    
    #判断 image的样本数量和label是否相同
    if np.shape(images)[0] != n_samples:  
        raise ValueError('Images size %d does not match label size %d.' %(images.shape[0], n_samples))  
    
    writer = tf.python_io.TFRecordWriter(filename)  
    print('\nTransform start......')  
    for i in range(len(images)):  
        try: 
            image_raw_data = tf.gfile.FastGFile(images[i],'r').read()
            img_data = tf.image.decode_png(image_raw_data)
              
            label = int(labels[i])  
            example = tf.train.Example(features=tf.train.Features(feature={  
                            'label':int64_feature(label),  
                            'image_raw': bytes_feature(image_raw)}))  
            writer.write(example.SerializeToString())  
        except IOError as e:  
            print('Could not read:', images[i])  
            print('error: %s' %e)  
            print('Skip it!\n')  
    writer.close() 

二、读取数据,并绘图
# read the data from tfrecoder
def read_and_decode(tfrecords_file):  
	'''''read and decode tfrecord file, generate (image, label) batches 
	Args: 
	tfrecords_file: the directory of tfrecord file 
	batch_size: number of images in each batch 
	Returns: 
	image: 4D tensor - [batch_size, width, height, channel] 
	label: 1D tensor - [batch_size] 
	'''  
	# make an input queue from the tfrecord file  
	filename_queue = tf.train.string_input_producer([tfrecords_file])  


	reader = tf.TFRecordReader()  
	_, serialized_example = reader.read(filename_queue)  
#解析读入的样例
	img_features = tf.parse_single_example(  
		                        serialized_example,  
		                        features={  
		                               'label': tf.FixedLenFeature([], tf.int64),  
		                               'image_raw': tf.FixedLenFeature([], tf.string),  
		                               })  
#将字符串解析成相应的数组
	image = tf.decode_raw(img_features['image_raw'], tf.uint8)  
#转化成图片的格式
	image = tf.reshape(image, [465, 315,3])
	sess = tf.Session()
	coord = tf.train.Coordinator()
	threads = tf.train.start_queue_runners(sess=sess,coord=coord)
	image , label = sess.run([image,label])
	print image
	plt.imshow(image)
	plt.show()
	
	sess.close()

read_and_decode('/home/tensor/Desktop/tia.tfrecords')


三、结果



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值