一. 将图片转换为 tfrecords 文件
我的任务是图像到图像（image-to-image）的转换，因此每条记录都保存一对图片：ghost 图和与之对应的 label 图。
import tensorflow as tf
import cv2
import glob
if __name__ == "__main__":
    # Write every (ghost, label) image pair found under img/ into a single
    # TFRecord file. Each tf.train.Example holds the raw uint8 pixel bytes of
    # both images under the keys 'ghost' and 'label'; no shape is stored.
    import os

    tfrecord_file = 'picture_train.tfrecords'
    # glob() returns paths in arbitrary, filesystem-dependent order; sort so
    # the record order is reproducible between runs.
    label_path_list = sorted(glob.glob("img/*label*"))

    # The context manager closes (and flushes) the writer even if a read or a
    # write below raises, instead of leaking the file handle.
    with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
        # Stream one pair at a time instead of loading every image into
        # memory before writing anything.
        for label_path in label_path_list:
            # The ghost image's file name is the label's file name with
            # "label" replaced by "ghost". Only the file name is rewritten, so
            # a directory whose name contains "label" is left untouched.
            head, name = os.path.split(label_path)
            ghost_path = os.path.join(head, name.replace("label", "ghost"))

            label_img = cv2.imread(label_path)
            ghost_img = cv2.imread(ghost_path)
            # cv2.imread returns None (it does not raise) when a file is
            # missing or unreadable; fail with a clear message instead of an
            # AttributeError on .tobytes() further down.
            for path, img in ((label_path, label_img), (ghost_path, ghost_img)):
                if img is None:
                    raise FileNotFoundError(f"could not read image: {path}")
            # Only raw bytes are stored, so the reader must reshape both
            # images to one fixed shape; reject pairs that could not round-trip.
            if ghost_img.shape != label_img.shape:
                raise ValueError(
                    f"shape mismatch: {ghost_path} {ghost_img.shape} vs "
                    f"{label_path} {label_img.shape}")

            features = tf.train.Features(feature={
                'ghost': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[ghost_img.tobytes()])),
                'label': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[label_img.tobytes()])),
            })
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
二. 将 tfrecords 文件转换回图片
解析时使用的特征名（'ghost'、'label'）和数据类型必须与上面写入时的代码一一对应；由于写入时只保存了原始像素字节、没有保存图片尺寸，这里需要按已知尺寸（440×440×3）进行 reshape。
import tensorflow as tf
import cv2
def load_image(tf_files=("picture_train.tfrecords",), load_w=440,
               num_images=100, out_dir="img"):
    """Decode ghost/label image pairs from TFRecord files and save them as JPEGs.

    This is the inverse of the writer script: each record holds two raw uint8
    byte strings under the keys 'ghost' and 'label'.

    Args:
        tf_files: TFRecord file paths to read.
        load_w: side length used to reshape the raw bytes. Every stored image
            must be exactly load_w x load_w x 3, because the writer does not
            record the image shape.
        num_images: number of records to pull from the input queue.
            string_input_producer cycles through the files indefinitely (no
            num_epochs is set), so a value larger than the record count
            repeats records.
        out_dir: directory that receives ``{i}_ghost.jpg`` and
            ``{i}_label.jpg``.

    Raises:
        IOError: if cv2.imwrite reports that an image could not be written.
    """
    filename_queue = tf.train.string_input_producer(list(tf_files))
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # The feature keys and types must match what the writer stored.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'ghost': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.string),
        })
    _ghost = tf.decode_raw(features['ghost'], tf.uint8, name='decode_raw')
    _label = tf.decode_raw(features['label'], tf.uint8, name='decode_raw')
    _ghost = tf.reshape(_ghost, [load_w, load_w, 3])
    _label = tf.reshape(_label, [load_w, load_w, 3])

    # The session is closed on exit. The queue-runner threads are stopped and
    # joined before that, so the process does not hang or log errors from
    # threads still using a closed session.
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(num_images):
                ghost, label = sess.run([_ghost, _label])
                for tag, img in (("ghost", ghost), ("label", label)):
                    path = f"{out_dir}/{i}_{tag}.jpg"
                    # cv2.imwrite signals failure by returning False rather
                    # than raising.
                    if not cv2.imwrite(path, img):
                        raise IOError(f"failed to write image: {path}")
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    load_image()
参考:https://blog.youkuaiyun.com/qq_27825451/article/details/83301811
官网描述