制作tfrecord文件时,先考虑我想保存图片的什么信息。
直接将图片输入到卷积神经网络中,需要图片的名称、维度和内容,所以制作tfrecord文件时需要将这三个内容记录下来。
第一步、读取文件夹中的图片。因为项目之后要用到opencv,所以我直接使用cv2读取文件。
def read_image(file_path, resize_h=0, resize_w=0, nomalization=False):
"""
读取图片,默认情况下返回numpy.uint8类型的图片
:param file_path:图片所在路径
:param resize_h:重新调整图片的高度,默认情况下不对图片进行reshape
:param resize_w:重新设置图片的宽度,默认情况下不对图片进行reshape
:param nomalization:默认为False,此时不对图片归一化处理;当为True时,对图片像素值归一化处理
:return:默认情况下返回np.unit8类型的图片,归一化时返回[resize_h, resize_w, 3]的float64类型的图片
"""
or_image = cv2.imread(file_path)
if resize_h > 0 and resize_w > 0:
or_image = cv2.resize(or_image, (resize_w, resize_h))
if nomalization:
or_image = or_image / 255.0
return or_image
第二步,对图片构造example信息。如何构造example信息在这里已经记录过了:https://blog.youkuaiyun.com/Mr_None/article/details/106384102
def create_example(img):
"""
为每张图片创建一个example消息,并将其序列化
:param img:需要创建example的图像
:return:序列化后的Example对象
"""
# np.tostring方法将np数组转化成bytes类型,方便写入tfrecord,但这个方法会丢失矩阵的维度信息,所以还需要记录图片的长宽信息
img_content = img.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'img_content': bytes_feature(img_content),
'img_height': int64_feature(img.shape[0]),
'img_width': int64_feature(img.shape[1]),
'img_chanel': int64_feature(img.shape[2])
}))
example_str = example.SerializeToString()
return example_str
第三步,将每张图片的example信息写入record文件,此时需要遍历文件夹下的所有图片,这里我使用os模块,用os.listdir()读取当前文件夹下所有文件名称,返回一个列表,列表元素为每个文件的名称,类型为字符串。
def record(file_path, record_path, resize_h=0, resize_w=0, normalization=False):
"""
将文件夹中的图片做成tfrecord文件
:param file_path: 数据集所在的文件夹路径
:param record_path: tfrecord文件所在路径
:param resize_h: 需要重新设置的图片的高度,默认情况下不进行reshape
:param resize_w: 需要重新设置的图片的宽度,默认情况下不进行reshape
:param normalization: 归一化参数,默认为不归一化
:return: None
"""
img_list = os.listdir(file_path)
tf_writer = tf.python_io.TFRecordWriter(record_path)
for img_name in img_list:
img_path = os.path.join(file_path, img_name)
img = read_image(img_path, resize_h, resize_w, normalization)
example = create_example(img)
tf_writer.write(example)
print('We had record your data to tfrecord file. You can see it in the file named: ' + record_path)
tf_writer.close()
有个额外的问题,当我使用argparse模块,将源图像文件夹的路径和生成的record文件的路径作为可选参数添加到参数列表后,成功运行
import argparse
import tensorflow as tf
import cv2
import os
# 创建一个解析器
parser = argparse.ArgumentParser(description='The source image path and record path')
# # 添加参数
parser.add_argument('--source_path', type=str, default='D:\\CHR\\data\\reconstruction materials\\video1', help='Raw data file path.')
parser.add_argument('--record_path', default='record', help='The path of tfrecord file.')
# parser.add_argument('--record_path', type=str, default='D:\\CHR\\code\\patent_spring\\ImgToRecord\\video2\\record', help='The path of tfrecord file.')
# 解析参数
args = parser.parse_args()
在命令行运行程序,更改record文件路径(也就是不适用上面代码提供的默认值),结果失败,如下:
但是我修改代码,将上面命令行中提供的路径修改成默认的record文件路径,可以运行成功,如下:
import argparse
# 创建一个解析器
parser = argparse.ArgumentParser(description='The source image path and record path')
# # 添加参数
parser.add_argument('--source_path', default='D:\\CHR\\data\\reconstruction materials\\video1', help='Raw data file path.')
# parser.add_argument('--record_path', default='D:\\CHR\\code\\data_build\\record', help='The path of tfrecord file.')
parser.add_argument('--record_path', default='D:\\CHR\\code\\patent_spring\\ImgToRecord\\video2\\record', help='The path of tfrecord file.')
# 解析参数
args = parser.parse_args()
....
record(args.source_path, args.record_path)
还要再学习学习argparse模块的用法