python如何读取tfrecord文件_Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取...

本文详细介绍了如何在Tensorflow中使用slice_input_producer和string_input_producer进行数据批量读取,以及如何打包和读取TFRecord文件。通过实例展示了读取CSV、图片和TFRecord文件的方法,对于理解和应用Tensorflow数据处理非常有帮助。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

单一数据读取方式:

第一种:slice_input_producer()

# 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...]

[images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

第二种:string_input_producer()

# 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看

file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True)

reader = tf.WholeFileReader() # 定义文件读取器

key, value = reader.read(file_queue) # key:文件名;value:文件中的内容

!!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。

!!!如果它不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量

!!!以上两种方法都可以生成文件名队列。

(随机)批量数据读取方式:

batchsize=2  # 每次读取的样本数量

tf.train.batch(tensors, batch_size=batchsize)

tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue

!!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()

TFRecord文件的打包与读取

一、单一数据读取方式

第一种:slice_input_producer()

def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)

案例1:

import tensorflow as tf

images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']

labels = [1, 2, 3, 4]

# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

# 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错

# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)

data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

print(type(data)) #

with tf.Session() as sess:

# sess.run(tf.local_variables_initializer())

sess.run(tf.local_variables_initializer())

coord = tf.train.Coordinator() # 线程的协调器

threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器

for i in range(10):

print(sess.run(data))

coord.request_stop()

coord.join(threads)

"""

运行结果:

[b'image2.jpg', 2]

[b'image1.jpg', 1]

[b'image3.jpg', 3]

[b'image4.jpg', 4]

[b'image2.jpg', 2]

[b'image1.jpg', 1]

[b'image3.jpg', 3]

[b'image4.jpg', 4]

[b'image2.jpg', 2]

[b'image3.jpg', 3]

"""

!!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],

!!!num_epochs设置

第二种:string_input_producer()

def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, s

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值