先贴出常规的读写tfrecord文件的代码,按分类好的文件夹读取图片,并以图片所在的文件夹作为其对应的标签。
import os
import tensorflow as tf
from PIL import Image #注意Image,后面会用到
IMAGE_SIZE = 224
#%%
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#%%
def SaveTF(ImgPath, SaveTFPath, TFname):
writer= tf.python_io.TFRecordWriter(SaveTFPath + TFname + ".tfrecords") #要生成的文件
classes = []
for root, dirs, files in os.walk(ImgPath):
for name in dirs:
classes.append(name) # acquire class name
for index,name in enumerate(classes):
class_path=os.path.join(ImgPath,name)
for img_name in os.listdir(class_path):
img_path=os.path.join(class_path,img_name) #每一个图片的地址
img=Image.open(img_path)
img= img.resize((IMAGE_SIZE,IMAGE_SIZE))
img_raw=img.tobytes()#将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": _int64_feature(index), #label: 1, 2, 3, ....
'img_raw': _bytes_feature(img_raw)
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
print("tfrecord file has stored!")
#%%
#read TF data
def ReadTF(path):
filename_queue = tf.train.string_input_producer([path + '.tfrecords'])#生成一个queue队列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#将image数据和label取出来
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [IMAGE_SIZE, IMAGE_SIZE ,3]) #reshape为128*128的3通道图片k
img = tf.cast(img, tf.float32) * (1. / 255) #在流中抛出img张量
label = tf.cast(features['label'], tf.int32) #在流中抛出label张量
return img, label
然而当图片数据是从别的文件中读出的numpy矩阵形式或者需要向tfrecord文件中存储非图像矩阵时,该如何处理呢?
我首先尝试了仍然按上述方式,即将SaveTF函数中的代码改为:
def SaveTF(Images, labels,SaveTFPath, TFname):
writer= tf.python_io.TFRecordWriter(SaveTFPath + TFname + ".tfrecords") #要生成的文件
Images_shape = Images.shape
for i in range(Images_shape[0]):
img = Images[i] #取第i张图像的数据
img= img.resize((IMAGE_SIZE,IMAGE_SIZE))
img_raw=img.tobytes()#将图片转化为二进制格式
label = labels[i] # 取第i张图像的标签
example = tf.train.Example(features=tf.train.Features(feature={
"label": _int64_feature(label), #label: 1, 2, 3, ....
'img_raw': _bytes_feature(img_raw)
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
print("tfrecord file has stored!")
看似没什么问题对不对?可是测试的时候就报错了,
'img_raw': _bytes_feature(img_raw)
报错的代码是上面这行,错误具体忘记是什么了,大概的意思就是说数据类型不匹配。
百度了很久,才发现解决方案,如下:(请忽略我的label类型,我是做了一个特殊处理,处理成了矩阵)
def SaveTF(Images,Groundtruth,SaveTFPath, TFname,PIC_NUM):
writer= tf.python_io.TFRecordWriter(SaveTFPath + TFname + str(PIC_NUM)+ ".tfrecords") #要生成的文件
Images_shape = Images.shape
for i in range(Images_shape[0]):
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': _bytes_feature(tf.compat.as_bytes(Images[i].tostring())),
'label': _bytes_feature(tf.compat.as_bytes(Groundtruth[i].tostring()))#label: 1, 2, 3, ....
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
print("Tfrecord file has stored!")
关键代码为:
'img_raw': _bytes_feature(tf.compat.as_bytes(Images[i].tostring())),
百度了一下,tf.compat 模块是一个tensorflow的类型转换模块,详细介绍如下:(转自 Tensorflow的基本使用 )
TensorFlow Python / tf.compat
- 模块:tf.compat
- tf.compat.as_bytes
- tf.compat.as_str_any
- tf.compat.as_text
tf.compat
定义在:tensorflow/python/util/compat.py
与 Python 2 和 3 具有兼容性的函数。
转换例程
除了以下功能之外,as_str 将对象转换为 str。
类型
兼容性模块还提供以下类型:
- bytes_or_text_types
- complex_types
- integral_types
- real_types
功能
as_bytes(...):将字节或 unicode 转换为 bytes,使用 UTF-8 编码进行文本处理。
as_str(...):将字节或 unicode 转换为 bytes,使用 UTF-8 编码进行文本处理。
as_str_any(...):转换 str 为 str(value),但 as_str 用于 bytes。
as_text(...):以 unicode 字符串的形式返回给定的参数。
其他成员
bytes_or_text_types
complex_types
integral_types
real_types
追踪溯源,根本问题是numpy矩阵即使用了 tf.tobytes(),还是转换不成bytes,具体原因我也不清楚,如果有大神知道还请赐教。而tf.compat.as_bytes就可以将矩阵转换成bytes类型了,存储就能顺利进行。
另外,在读取数据的时候我还发现了一个之前一直没有注意的问题,读取和存储tfrecord文件的数据类型必须严格一致,比如存入的是int32格式,读取的时候就只能是int32,用其他的类型比如int64会在reshape矩阵的时候说矩阵大小不一致。
如果是在训练过程中,执行sess.run时报错像下面这样:
Caused by op 'shuffle_batch_3', defined at:
File "/home/song/anaconda3/bin/ipython", line 6, in <module>
sys.exit(IPython.start_ipython())
File "/home/song/anaconda3/lib/python3.5/site-packages/IPython/__init__.py", line 119, in start_ipython
return launch_new_instance(argv=argv, **kwargs)
File "/home/song/anaconda3/lib/python3.5/site-packages/traitlets/config/application.py", line 653, in launch_instance
app.start()
.....
OutOfRangeError (see above for traceback): RandomShuffleQueue '_29_shuffle_batch_3/random_shuffle_queue' is closed and has insufficient elements (requested 1, current size 0)
[[Node: shuffle_batch_3 = QueueDequeueManyV2[component_types=[DT_UINT8, DT_INT64], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_3/random_shuffle_queue, shuffle_batch_3/n)]]
主要原因是图像的读取尺寸不对,如果数据是从tfrecord文件中读取出来的,很大可能是由于读取类型不一致导致尺寸不一致而无法载入batch数据。