读写tfreord文件——矩阵的存储和读取

本文介绍如何使用TensorFlow的TFRecord格式存储和读取图像数据,并解决存储非图像矩阵及类型不匹配等问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

先贴出常规的读写tfrecord文件的代码,按分类好的文件夹读取图片,并以图片所在的文件夹作为其对应的标签。


import os 
import tensorflow as tf 
from PIL import Image  #注意Image,后面会用到

IMAGE_SIZE = 224

#%%

def _int64_feature(value):  
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  
def _bytes_feature(value):  
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  

    
#%%
def SaveTF(ImgPath, SaveTFPath, TFname):
    
    writer= tf.python_io.TFRecordWriter(SaveTFPath + TFname + ".tfrecords") #要生成的文件
    classes = []
    for root, dirs, files in os.walk(ImgPath):
        for name in dirs:
            classes.append(name)  # acquire class name
            
    for index,name in enumerate(classes):
        class_path=os.path.join(ImgPath,name)
        for img_name in os.listdir(class_path): 
            img_path=os.path.join(class_path,img_name) #每一个图片的地址
            img=Image.open(img_path)
            img= img.resize((IMAGE_SIZE,IMAGE_SIZE))
            img_raw=img.tobytes()#将图片转化为二进制格式
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": _int64_feature(index), #label: 1, 2, 3, ....
                'img_raw': _bytes_feature(img_raw)
            })) #example对象对label和image数据进行封装
    
    
    
            writer.write(example.SerializeToString())  #序列化为字符串
            
        
    writer.close()
    print("tfrecord file has stored!")


#%%
#read TF data
def ReadTF(path): 
    filename_queue = tf.train.string_input_producer([path + '.tfrecords'])#生成一个queue队列
 
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)#返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })#将image数据和label取出来
 
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [IMAGE_SIZE, IMAGE_SIZE ,3])  #reshape为128*128的3通道图片k
    img = tf.cast(img, tf.float32) * (1. / 255) #在流中抛出img张量
    label = tf.cast(features['label'], tf.int32) #在流中抛出label张量
    return img, label

然而当图片数据是从别的文件中读出的numpy矩阵形式或者需要向tfrecord文件中存储非图像矩阵时,该如何处理呢?

我首先尝试了仍然按上述方式,即将SaveTF函数中的代码改为:


def SaveTF(Images, labels,SaveTFPath, TFname):
    
    writer= tf.python_io.TFRecordWriter(SaveTFPath + TFname + ".tfrecords") #要生成的文件
    Images_shape = Images.shape
    for i in range(Images_shape[0]):
            img = Images[i]  #取第i张图像的数据
            img= img.resize((IMAGE_SIZE,IMAGE_SIZE))
            img_raw=img.tobytes()#将图片转化为二进制格式
            label = labels[i] # 取第i张图像的标签
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": _int64_feature(label), #label: 1, 2, 3, ....
                'img_raw': _bytes_feature(img_raw)
            })) #example对象对label和image数据进行封装            
            writer.write(example.SerializeToString())  #序列化为字符串                    
    writer.close()
    print("tfrecord file has stored!")
看似没什么问题对不对?可是测试的时候就报错了,
'img_raw': _bytes_feature(img_raw)

报错的代码是上面这行,错误具体忘记是什么了,大概的意思就是说数据类型不匹配。

百度了很久,才发现解决方案,如下:(请忽略我的label类型,我是做了一个特殊处理,处理成了矩阵)

def SaveTF(Images,Groundtruth,SaveTFPath, TFname,PIC_NUM):
    
    writer= tf.python_io.TFRecordWriter(SaveTFPath + TFname + str(PIC_NUM)+ ".tfrecords") #要生成的文件
    Images_shape = Images.shape    
    for i in range(Images_shape[0]):
      
      example = tf.train.Example(features=tf.train.Features(feature={
                  'img_raw': _bytes_feature(tf.compat.as_bytes(Images[i].tostring())),
                  'label': _bytes_feature(tf.compat.as_bytes(Groundtruth[i].tostring()))#label: 1, 2, 3, ....

              })) #example对象对label和image数据进行封装
      
      writer.write(example.SerializeToString())  #序列化为字符串
    writer.close()
    print("Tfrecord file has stored!")

关键代码为:

 'img_raw': _bytes_feature(tf.compat.as_bytes(Images[i].tostring())),

百度了一下,tf.compat 模块是一个tensorflow的类型转换模块,详细介绍如下:(转自  Tensorflow的基本使用

TensorFlow Python / tf.compat

  • 模块:tf.compat
  • tf.compat.as_bytes
  • tf.compat.as_str_any
  • tf.compat.as_text
tf.compat

定义在:tensorflow/python/util/compat.py

与 Python 2 和 3 具有兼容性的函数。

转换例程

除了以下功能之外,as_str 将对象转换为 str。

类型

兼容性模块还提供以下类型:

  • bytes_or_text_types
  • complex_types
  • integral_types
  • real_types
功能

as_bytes(...):将字节或 unicode 转换为 bytes,使用 UTF-8 编码进行文本处理。

as_str(...):将字节或 unicode 转换为 bytes,使用 UTF-8 编码进行文本处理。

as_str_any(...):转换 str 为 str(value),但 as_str 用于 bytes。

as_text(...):以 unicode 字符串的形式返回给定的参数。

其他成员

bytes_or_text_types

complex_types

integral_types

real_types


追踪溯源,根本问题是numpy矩阵即使用了 tf.tobytes(),还是转换不成bytes,具体原因我也不清楚,如果有大神知道还请赐教。而tf.compat.as_bytes就可以将矩阵转换成bytes类型了,存储就能顺利进行。

另外,在读取数据的时候我还发现了一个之前一直没有注意的问题,读取和存储tfrecord文件的数据类型必须严格一致,比如存入的是int32格式,读取的时候就只能是int32,用其他的类型比如int64会在reshape矩阵的时候说矩阵大小不一致。

如果是在训练过程中,执行sess.run时报错像下面这样:

Caused by op 'shuffle_batch_3', defined at:
  File "/home/song/anaconda3/bin/ipython", line 6, in <module>
    sys.exit(IPython.start_ipython())
  File "/home/song/anaconda3/lib/python3.5/site-packages/IPython/__init__.py", line 119, in start_ipython
    return launch_new_instance(argv=argv, **kwargs)
  File "/home/song/anaconda3/lib/python3.5/site-packages/traitlets/config/application.py", line 653, in launch_instance
    app.start()
    .....
OutOfRangeError (see above for traceback): RandomShuffleQueue '_29_shuffle_batch_3/random_shuffle_queue' is closed and has insufficient elements (requested 1, current size 0)
     [[Node: shuffle_batch_3 = QueueDequeueManyV2[component_types=[DT_UINT8, DT_INT64], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_3/random_shuffle_queue, shuffle_batch_3/n)]]
主要原因是图像的读取尺寸不对,如果数据是从tfrecord文件中读取出来的,很大可能是由于读取类型不一致导致尺寸不一致而无法载入batch数据。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值