python tensorflow学习(二) tensorflow数据的生成与读取

本文介绍了TensorFlow中的数据生成与读取,包括队列的创建与使用,CSV文件的创建、读取,以及TFRecords文件的创建与读取。通过示例展示了如何使用FIFOQueue和RandomShuffleQueue,以及如何处理CSV和TFRecords文件中的图像数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tensorflow数据的生成与读取

明天就开始忙起来了,可能不会日更了,但还是希望能坚持把这个系列做完吧。
大纲

  • tensorflow的队列
  • csv文件的创建和读取
  • tensorflow文件的创建和读取

tensorflow的队列

队列的创建

首先,创建一个队列要选定数据的出入类型,例如FIFOQueueRandomShffleQueue方式,前者为先入先出,后者为随即元素出列。
例:

Q = tf.FIFOQueue(3, "float")

该方法创建了一个先入先出的队列,队列数据个数为3,队列中元素类型为float
然后,可以使用enqueue_many() 函数对该队列进行数据填充。

sess = tf.Session()
init = Q.enqueue_many(([0.1, 0.2, 0.3],))
sess.run(init)

要注意的是,tensorflow中的任何操作都要在会话Session中完成,所以要等到会话的==run()==函数完成后,该队列才会真正被填充。
完整例子:

Q = tf.FIFOQueue(3, "float")

sess = tf.Session()
init = Q.enqueue_many(([0.1, 0.2, 0.3],))
init2 = Q.dequeue()
init3 = Q.enqueue(1.)

sess.run(init)
sess.run(init2)
sess.run(init3)

quelen = sess.run(Q.size())
for i in range(quelen):
    print(sess.run(Q.dequeue()))

tensorflow中对队列的操作和python的队列基本相似
运行结果:

0.2
0.3
1.0

CSV文件的创建和读取

CSV(Comma-Separated Values)文件是最常用的文件存储方式,以纯文本形式存储表格数据。
这里采用图片地址和标签来写入CSV文件
图片目录如下:

创建:
import os

path = 'jpg'
filename = os.listdir(path)
strText = ""

with open("train_list.csv", "w") as fid:
    for a in range(len(filename)):
        strText = path+os.sep+filename[a] + "," + filename[a].split("_")[1][:1] + "\n"
        fid.write(strText)
    fid.close()

结果train_list.csv:

jpg\image_1.jpg,1
jpg\image_2.jpg,2
jpg\image_3.jpg,3
读取
image_add_list = []
image_label_list = []
with open("train_list.csv") as fid:
    for image in fid.readlines():
        image_add_list.append(image.strip().split(",")[0])
        image_label_list.append(image.strip().split(",")[1])

def get_image(image_path):
    return tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(image_path),
                                                                      channels=1), dtype=tf.float32)

for i in range(len(image_add_list)):
    print(image_label_list[i], get_image(image_add_list[i]))

其中用到的函数:
tf.read_file:读取图片地址
tf.image.decode_jpeg把读取的图片解码成jpg格式,并把通道数设为1
tf.image.convert_image_dtype把图像转换为张量
运行结果:

1 Tensor("convert_image:0", shape=(?, ?, 1), dtype=float32)
2 Tensor("convert_image_1:0", shape=(?, ?, 1), dtype=float32)
3 Tensor("convert_image_2:0", shape=(?, ?, 1), dtype=float32)

可以看到,该函数将图片和scv文件转换成了标签和张量格式

tensorflow文件的创建和读取

除了典型的CSV文件读取外,tensorflow还有专门的文件存储格式:TFRecords文件。

TFRecords文件的创建
a_data = 0.834
b_data = [17]
c_data = np.array([[0, 1, 2], [3, 4, 5]])
c = c_data.astype(np.uint8)
c_raw = c.tostring()  # 转化成字符串

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'a': tf.train.Feature(
                float_list=tf.train.FloatList(value=[a_data])  # 方括号表示输入为list
            ),
            'b': tf.train.Feature(
                int64_list=tf.train.Int64List(value=b_data)  # b_data本身就是列表
            ),
            'c': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[c_raw])  # c_raw被转化为byte格式
            )
        }
    )
)

# 创建TFRecords文件
writer = tf.python_io.TFRecordWriter("trainArray.tfrecords")
for i in range(1):
    # 创建样本example
    # ...
    serialized = example.SerializeToString()
    writer.write(serialized)
writer.close()

上面代码是TFRecords写入文件的经典格式,即对样本的序列化之后进行写操作完成。可以看到TFRecords文件接收三中数据格式,分别为BytesListInt64ListFloatList

TFRecords文件的读取

TFRecords文件的读取则有些麻烦,要将TFRecords中的数据以输入的格式读取出来
代码如下:

# 读取TFRecords文件
filename_queue = tf.train.input_producer(["trainArray.tfrecords"], num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'a': tf.FixedLenFeature([], tf.float32),
        'b': tf.FixedLenFeature([], tf.int64),
        'c': tf.FixedLenFeature([], tf.string)
    }
)

a = features['a']
b = features['b']
c_raw = features['c']

c = tf.decode_raw(c_raw, tf.uint8)
c = tf.reshape(c, [2, 3])

print(a)
print(b)
print(c)

输出结果:

Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(), dtype=int64)
Tensor("Reshape:0", shape=(2, 3), dtype=uint8)

太晚了,这一部分剩余内容明天再更^ ^

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值