Lesson 5: Reading Data with TensorFlow TFRecord

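This article shows how to use TensorFlow's TFRecord format to streamline a machine learning data pipeline, covering how the data is serialized, stored, and read back. Image data is converted to TFRecord format so that it can be managed and read efficiently.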

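Other common data formats can be used as well, but TFRecord is the better choice:

  1. It is based on protobuf, a binary format that is faster to transfer.
  2. It gives the data a uniform structure, which hides the underlying data layout from the rest of the pipeline.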
import tensorflow as tf
import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams["figure.figsize"] = (20,10)
train_df = pd.read_csv('train.csv')
display(train_df.head())
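   label  pixel0  pixel1  pixel2  pixel3  pixel4  pixel5  pixel6  pixel7  pixel8  ...  pixel774  pixel775  pixel776  pixel777  pixel778  pixel779  pixel780  pixel781  pixel782  pixel783
0      1       0       0       0       0       0       0       0       0       0  ...         0         0         0         0         0         0         0         0         0         0
1      0       0       0       0       0       0       0       0       0       0  ...         0         0         0         0         0         0         0         0         0         0
2      1       0       0       0       0       0       0       0       0       0  ...         0         0         0         0         0         0         0         0         0         0
3      4       0       0       0       0       0       0       0       0       0  ...         0         0         0         0         0         0         0         0         0         0
4      0       0       0       0       0       0       0       0       0       0  ...         0         0         0         0         0         0         0         0         0         0

5 rows × 785 columns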

label_df = train_df.pop(item='label')
train_values = train_df.values
train_labels = label_df.values

display(type(train_values))
display(train_values.shape)
display(type(train_labels))
display(train_labels.shape)
numpy.ndarray
(42000, 784)
numpy.ndarray
(42000,)

Example protobuf:

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature> feature = 1;
};

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
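
To make the nesting in the definition above concrete, here is a minimal pure-Python sketch of the same three levels. It is only an illustration, not the TensorFlow API: the dataclasses and the `kind` attribute are hypothetical stand-ins that mirror the field names of the protobuf messages, and the check in `__post_init__` mimics the `oneof` rule that a Feature holds at most one of the three list types.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Feature:
    """Stand-in for the Feature message: at most one list may be set."""
    bytes_list: Optional[List[bytes]] = None
    float_list: Optional[List[float]] = None
    int64_list: Optional[List[int]] = None
    kind: Optional[str] = field(default=None, init=False)

    def __post_init__(self):
        populated = [
            name for name in ("bytes_list", "float_list", "int64_list")
            if getattr(self, name) is not None
        ]
        if len(populated) > 1:
            raise ValueError("oneof violated: %s are all set" % populated)
        # Record which member of the oneof is in use (None if empty).
        self.kind = populated[0] if populated else None


@dataclass
class Features:
    """Stand-in for the Features message: a map from name to Feature."""
    feature: Dict[str, Feature] = field(default_factory=dict)


@dataclass
class Example:
    """Stand-in for the Example message: a wrapper around Features."""
    features: Features = field(default_factory=Features)


# One record shaped like the ones in this lesson: raw image bytes plus a label.
example = Example(features=Features(feature={
    "image_raw": Feature(bytes_list=[b"\x00\x01\x02"]),
    "label": Feature(int64_list=[7]),
}))
```

Each value is always wrapped in a list, even a single label, which is why the real API calls take `value=[...]`.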
# Create the TFRecord writer
writer = tf.python_io.TFRecordWriter('csv_train.tfrecords')

for i in range(train_values.shape[0]):
    # Cast to uint8 before serializing so each pixel takes one byte,
    # matching the tf.uint8 dtype used when decoding
    image_raw = train_values[i].astype(np.uint8).tobytes()

    # Build the Example protobuf
    example = tf.train.Example(
        features=tf.train.Features(feature={
            'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(train_labels[i])]))
        }))
    writer.write(example.SerializeToString())

writer.close()
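
For readers curious about what the writer puts on disk, the following is a rough pure-Python sketch of the record framing, based on my understanding of the TFRecord file format: each serialized Example is stored as an 8-byte little-endian length, a masked CRC-32C of that length, the payload, and a masked CRC-32C of the payload. The function names (`crc32c`, `masked_crc`, `frame_record`, `read_records`) are hypothetical helpers written for this illustration and are not part of TensorFlow.

```python
import struct


def _make_crc32c_table():
    # Table for the reflected CRC-32C (Castagnoli) polynomial.
    table = []
    for n in range(256):
        crc = n
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC_TABLE = _make_crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    # Rotate right by 15 bits and add a constant, all in 32-bit arithmetic.
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(payload):
    # length (uint64) | crc of length (uint32) | payload | crc of payload (uint32)
    length = struct.pack("<Q", len(payload))
    return (length + struct.pack("<I", masked_crc(length))
            + payload + struct.pack("<I", masked_crc(payload)))


def read_records(buf):
    # Walk the buffer record by record, verifying both checksums.
    offset = 0
    while offset < len(buf):
        length_bytes = buf[offset:offset + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack("<I", buf[offset + 8:offset + 12])
        if length_crc != masked_crc(length_bytes):
            raise ValueError("corrupted record length")
        payload = buf[offset + 12:offset + 12 + length]
        (data_crc,) = struct.unpack(
            "<I", buf[offset + 12 + length:offset + 16 + length])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupted record payload")
        yield payload
        offset += 16 + length


framed = frame_record(b"first record") + frame_record(b"second")
records = list(read_records(framed))
```

The framing adds a fixed 16 bytes per record, and the file itself carries no schema: the reader has to know the feature names and types, which is why the parsing step needs an explicit feature specification.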

Reading data from the TFRecord file

reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(['csv_train.tfrecords'])

_, serialized_record = reader.read(filename_queue)

features = tf.parse_single_example(serialized_record,
    features={
        # tf.FixedLenFeature returns a Tensor
        # tf.VarLenFeature returns a SparseTensor
        "image_raw": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64),
    })

# Reinterpret the raw bytes as uint8 pixel values
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)

with tf.Session() as session:
    session.run(tf.local_variables_initializer())
    session.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coord)

    for i in range(2):
        image, label = session.run([images, labels])

        display(label)
        display(image)
        print('-' * 40)

    # Stop the queue-runner threads before the session closes
    coord.request_stop()
    coord.join(threads)
1
array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)
----------------------------------------
0
array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)
----------------------------------------

If the session closes while the queue-runner threads are still enqueueing (that is, without calling coord.request_stop() and coord.join(threads) first), TensorFlow also logs:

INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.CancelledError'>, Enqueue operation was cancelled
     [[Node: input_producer_3/input_producer_3_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer_3, input_producer_3/RandomShuffle)]]
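
One detail of the decode step is easy to miss: tf.decode_raw only reinterprets the stored bytes using the dtype it is given, so the dtype used when serializing has to match the one used when decoding. The sketch below illustrates this with the standard-library `array` module as a stand-in for NumPy (an assumption made only to keep the example dependency-free); the `pixels` row is made-up data with the same length as one image in train.csv.

```python
from array import array

# One hypothetical 784-pixel row, with values in the 0-255 range.
pixels = [0, 128, 255] + [0] * 781

# Serialized as 64-bit integers: 8 bytes per pixel.
as_int64 = array("q", pixels).tobytes()
# Serialized as unsigned 8-bit integers: 1 byte per pixel.
as_uint8 = array("B", pixels).tobytes()

# Decoding both buffers byte by byte, the way a uint8 decode would.
decoded_wrong = list(as_int64)   # 6272 values, most of them padding zeros
decoded_right = list(as_uint8)   # the original 784 pixel values

print(len(as_int64), len(as_uint8))
```

Since pandas typically loads integer CSV columns as 64-bit integers, serializing a row without casting it to uint8 first would give the longer buffer, and a uint8 decode of it would return eight values per pixel instead of one.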