0. 导读
相信很多深度学习的初学者都看着大神教程搭过很多的CNN、RNN、VGG等网络模型,训练模型时用的基本都是mnist、fashion-mnist、cifar10等基准数据集,当面对自己实验室的或者自己网上爬下来的数据集时就会犯难,根本不知道怎么调整算法模型去读取我们的本地数据集。问题不大,这个对于每一个深度学习工作者都是漫漫探索路上的必经之路。以下我将分享自己的学习心得以及完整代码实现,为曾和我一样正犯难的小伙伴提供一些参考。
代码部分已验证无误,并做了完整注释,请放心食用!
关于读取数据,TensorFlow提供了3种读取方法:
Feeding:
placeholder,feed_dict由占位符代替数据,运行时在线填入数据;Reader:
从文件中直接读取,在一个计算图(tf.graph)的开始前,将文件读入队列(queue)中;Preloaded data:
预加载数据;
鉴于TensorFlow提供了标准的TFRecord格式,接下来我将介绍就是上述的第2种方法。利用tf.record标准口来读入文件。本程序主要包含以下3个核心部分:
制作TFrecord
读取TFrecord数据获得image和label
打印验证并保存生成的图片
1. 准备数据
先在网上下载不同类的图片集,例如几个品种的狗的图片。实验室数据集不方便公开,暂时简单使用网上下载的Dog的图片做介绍。此处已预先下载哈奇士、吉娃娃两种狗的照片各30张,如下:
2. 代码拆分解释
2.1 制作TFrecord
#将原始图片转换成需要的大小,并将其保存
#=========================================================================================
import os
import tensorflow as tf
from PIL import Image
#原始图片的储存位置
orig_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/train'
#生成图片的储存位置
gen_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/Re_train/image_data/inputdata'
#需要的识别类型
classes = {'husky', 'jiwawa'}
#样本总数
num_samples = 60
#制作TFRecords数据
def create_record():
writer = tf.python_io.TFRecordWriter(gen_picture + "/dogs_train.tfrecords")
for index, name in enumerate(classes):
class_path = orig_picture + "/" + name + "/"
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((64,64)) #设置需要转换的图片大小
img_raw = img.tobytes() #将图片转化为原生bytes
print(index,img_raw)
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
以下代码段为单独测试部分,结合上述代码块即可制作狗的训练集TFRecord数据,命名为dogs_train.tfrecords
,如生成文件如下图所示:
if __name__ == '__main__':
create_record()
print("Finished")
2.2 读取TFRecord数据获得image 和 label
def read_and_decode(filename):
#创建文件队列,不限读取的数量
filename_queue = tf.train.string_input_producer([filename])
#create a reader from file queue
reader = tf.TFRecordReader()
#reader从文件队列中读入一个序列化的样本
_, serialized_example = reader.read(filename_queue)
#get feature from serialized example
#解析符号化的样本
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)})
label = features['label']
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.reshape(img, [64,64,3])
#img = tf.cast(img, tf.float32) * (1. /255) -0.5
label = tf.cast(label, tf.int32)
return img, label
2.3 打印验证并保存生成的图片
if __name__ == '__main__':
# create_record()
batch = read_and_decode('dogs_train.tfrecords')
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess: #开启一个会话
sess.run(init_op)
coord = tf.train.Coordinator() #多线程管理器
threads = tf.train.start_queue_runners(coord=coord)
for i in range(num_samples):
example, lab = sess.run(batch) #在会话中取出image 和label
img = Image.fromarray(example, 'RGB') #这里Image是之前提到的 Image.fromarray()实现array 到image的转换
img.save(gen_picture + '/' + str(i) + '_Label_' + str(lab) + '.jpg') #保存图片;注意cwd后面加上‘/’
coord.request_stop()
coord.join(threads)
sess.close()
3. 完整代码实现
"""
**利用TensorFlow训练自己的图片数据(1)——预处理
@author: <Colynn Johnson>
@direct: https://blog.youkuaiyun.com/ywx1832990/article/details/78609323
@date: 2020-8-20
"""
"""
首先,我们需要准备训练的原始数据,本次训练为图像分类识别,因而一开始,笔者从网上随机的下载了Dog的四种类别:
husky,jiwawa。每种类别30种,一共60张图片。在训练之前,需要做的就是进行图像的预处理,
即将这些大小不一的原始图片转换成我们训练需要的shape。
编程实现包括:制作TFrecord,读取TFrecord数据获得image和label,打印验证并保存生成的图片
"""
#将原始图片转换成需要的大小,并将其保存
#=========================================================================================
import os
import tensorflow as tf
from PIL import Image
#原始图片的储存位置
orig_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/train'
#生成图片的储存位置
gen_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/Re_train/image_data/inputdata'
#需要的识别类型
classes = {'husky', 'jiwawa'}
#样本总数
"""待定!!"""
num_samples = 60
#制作TFRecords数据
def create_record():
writer = tf.python_io.TFRecordWriter(gen_picture + "/dogs_train.tfrecords")
for index, name in enumerate(classes):
class_path = orig_picture + "/" + name + "/"
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((64,64)) #设置需要转换的图片大小
img_raw = img.tobytes() #将图片转化为原生bytes
print(index,img_raw)
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
# if __name__ == '__main__':
# create_record()
# print("Finished")
#=================================================================================================================
"""
读取TFRecord数据获得image 和 label
"""
def read_and_decode(filename):
#创建文件队列,不限读取的数量
filename_queue = tf.train.string_input_producer([filename])
#create a reader from file queue
reader = tf.TFRecordReader()
#reader从文件队列中读入一个序列化的样本
_, serialized_example = reader.read(filename_queue)
#get feature from serialized example
#解析符号化的样本
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)})
label = features['label']
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.reshape(img, [64,64,3])
#img = tf.cast(img, tf.float32) * (1. /255) -0.5
label = tf.cast(label, tf.int32)
return img, label
#========================================================================================
"""
打印验证并保存生成的图片
"""
if __name__ == '__main__':
create_record()
batch = read_and_decode('dogs_train.tfrecords')
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess: #开启一个会话
sess.run(init_op)
coord = tf.train.Coordinator() #多线程管理器
threads = tf.train.start_queue_runners(coord=coord)
for i in range(num_samples):
example, lab = sess.run(batch) #在会话中取出image 和label
img = Image.fromarray(example, 'RGB') #这里Image是之前提到的 Image.fromarray()实现array 到image的转换
img.save(gen_picture + '/' + str(i) + '_Label_' + str(lab) + '.jpg') #保存图片;注意cwd后面加上‘/’
coord.request_stop()
coord.join(threads)
sess.close()
结果展示
每一幅图片的命名中,第二个数字则是 label,吉娃娃都为1,哈士奇都为0;通过对照图片,可以发现图片分类正确。
Reference: https://blog.youkuaiyun.com/ywx1832990/article/details/78609323