使用TensorFlow的TFRecord来保存和读取数据
在网上看别人写的程序,总觉得已经明白了,没想到自己来写,发现有好多的坑:
- TFRecordWrite/Read不支持中文路径(至少我这里)
- 一定要注意tf.train.Example的层次结构
- Features
- feature1
- {key=string, value=tf.train.Feature()}
- {key, value}
- feature2
- feature1
- Features
- 我使用parse_single_example来解析TFRecordReader后的记录
- features一定要与TFRecordWriter对应
- 在使用string_input_producer来读记录的时候,如果设置了num_epochs的值,一定还要使用local_variables_initializer初始化一次(仅仅用global_variables_initializer会报错)
代码如下
本代码目的,CSV文件中包含JPG文件名和类型,需要按照一定比例分别创建Train, Validation和Test数据集,集合文件为TFRecordFormat,然后用TFRecordReader方式读出来:
“`python
import io
import sys
import os
import csv
import random
from PIL import Imge
#filepath = 'D:/文件/数据资料/图像数据/照片/正面分割缩小'
#TFRecord不支持中文路径
filepath = 'D:\\temp'
csvfilename='BL2-CO3.csv'
#三种集合的文件名
trainfile = 'train.tfrecord'
valfile = 'validation.tfrecord'
testfile = 'test.tfrecord'
imgwidth = 1024
imgheight = 384
#创建数据集合
def CreateData():
#设置各集合的比例
trainratio = 0.7
valratio = 0.15
testratio = 0.15
#读取CSV文件,第一列是文件名,第二列是类型
fileList= []
with open(os.path.join(filepath, csvfilename), newline='') as csvfile:
spamreader = csv.reader(csvfile, delimiter=',')
isheader = True
for row in spamreader:
if isheader==False:
fileList.append([row[0], int(row[1])])
isheader = False;
#打乱顺序
filecount = len(fileList)
random.shuffle(fileList)
#创建训练、验证和测试集合
index1 = round(filecount * trainratio)
trainlist = fileList[0:index1]
index2 = round(filecount*valratio) + index1;
vallist = fileList[index1:index2]
testlist = fileList[index2:]
#保存列表到tfrecord文件
def SaveDataSet(datalist, savepath):
writer = tf.python_io.TFRecordWriter(savepath)
for item in datalist:
img = Image.open(item[0])
img.load()
#定义label和image两个属性
#使用flatten()将多维图像数据扁平化[height, width, channel] -->[byte]
#使用tostring()转换为byte array
record = tf.train.Example(
features = tf.train.Features(
feature = {
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[item[1]])),
'image':tf.train.Feature(bytes_list= tf.train.BytesList(value=[np.asarray(img, dtype=np.uint8).flatten().tostring()]))
}
)
)
writer.write(record.SerializeToString())
writer.close()
SaveDataSet(trainlist, os.path.join(filepath, trainfile))
SaveDataSet(vallist, os.path.join(filepath, valfile))
SaveDataSet(testlist, os.path.join(filepath, testfile))
#读取数据集
def TestData(filename, epochs):
#读取一条记录 (在线程里面调用)
def ReadOneRecord(filename, epochs):
#创建文件名队列,这里只有一个文件
filename_queue = tf.train.string_input_producer([filename], num_epochs=epochs)
#开始读取记录
reader = tf.TFRecordReader()
_,record = reader.read(filename_queue)
#解析记录
features = tf.parse_single_example(
record,
features={
'label':tf.FixedLenFeature([], tf.int64),
'image':tf.FixedLenFeature([],tf.string)
}
)
#解析出来的内容
label = features['label']
image = tf.decode_raw(features['image'], tf.uint8)
image = tf.reshape(image, [imgheight, imgwidth, 3])
return label, image
#获取所有集合数据
def Feech_Data(sess, coord, threads, label_batch, image_batch):
try:
while not coord.should_stop():
#如果没有取完数据(num_epochs没有结束)
label, image = sess.run([label_batch, imag e_batch])
#取数据
print('labelsize=%d, imagesize=%d'%(label.size, image.size))
except tf.errors.OutOfRangeError: #表示num_epoch结束了
print('done training')
finally:
coord.request_stop()
#等待所有线程结束
coord.join(threads)
###获取数据##
#引用记录解析过程
label, image = ReadOneRecord(filename, epochs)
#TF批次取数据,capacity和min_after_dequeue参数设置比较重要
#网上有不少讲如何设置的,我只是实验,没有讲究
label_batch, image_batch = tf.train.shuffle_batch([label, image], batch_size=5, capacity=30, min_after_dequeue=10, num_threads=2)
#获取shuffle_batch创建的线程
#queues = tf.get_collection(tf.GraphKeys.QUEUE_RUS)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
#一定要运行一下这个初始化,否则num_epochs报错
init = tf.local_variables_initializer()
sess.run(init)
#控制数据获取线程是否结束
coord = tf.train.Coordinator()
#启动数据获取线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
Feech_Data(sess=sess, coord=coord, threads=threads, label_batch=label_batch, image_batch=image_batch)
sess.close()
###### 1=创建数据, 0=获取数据######
if 0 :
CreateData()
else:
TestData(os.path.join(filepath, trainfile), 50)