在学习tensorflow时,需要训练自己的数据,则需要自己写代码读入数据、组装成batch。主要分为两个步骤
我的数据为三维的图像数据,后缀为nii。一般的二维图像数据也差不多是这个处理步骤
一:制作TFrecord文件。
假如目前在本地磁盘中有两类数据,已经按8:2的比例分为训练集和测试集,目录结构:
/train
/AD
/NC
/test
/AD
/NC
制作训练集或测试集时需要将两类数据打乱,否则文件中前几百个全是AD,后几百个全是NC。
第一步:得到AD和NC中的每个数据的绝对路径列表,设置标签NC为0,AD为1。连接列表后打乱列表(即打乱数据)。
第二步:加载列表中每个数据和对应标签,写入TFrecord文件。可以全部写入一个TFrecord文件,若数据量大也可以写入多个TFrecord文件,这里每300个数据写入一个TFrecord文件。
import os
import tensorflow as tf
import numpy as np
import nibabel as nib
#AD\NC二分类
#将训练集或测试集数据随机打乱放在一起
def create_TFrecord(file_dir,tfrecord_dir):
####################################################################
#############得到所有数据的路径和标签列表################################
#存放数据的路径和标签列表
AD=[]
label_AD=[]
NC=[]
label_NC=[]
for file in os.listdir(file_dir+'/NC'):
NC.append(file_dir+'/NC'+'/'+file)
label_NC.append(0)
for file in os.listdir(file_dir+'/AD'):
AD.append(file_dir+'/AD'+'/'+file)
label_AD.append(1)
#将两个列表连接在一起
image_list=np.hstack((NC,AD))
label_list = np.hstack((label_NC,label_AD))
#转化为数组并转置
temp=np.array([image_list,label_list])
temp=temp.transpose()
#按行随机打乱
np.random.shuffle(temp)
images=list(temp[:,0])
all_label=list(temp[:,1])
labels = [int(float(i)) for i in all_label]
#################################################################
##############每300个数据放入一个TFrecord文件######################
file_num=0 #tfrecord文件计数
data_num=300 #一个tfrecord文件中数据个数
count=0 #数据计数
recordname=record_path+"-%d"%(file_num)
writer= tf.python_io.TFRecordWriter(recordname)# 创建一个writer
for image_path,label in zip(images,labels):
count=count+1
print(count,data_num)
if count>data_num:
count=0
file_num=file_num+1
recordname=record_path+"-%d"%(file_num)
writer= tf.python_io.TFRecordWriter(recordname)# 创建一个writer
#加载数据
img= nib.load(image_path).get_fdata()
#print(img.shape)
#转换为字节
img_raw=img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close() # 关闭writer
if __name__=="__main__":
dir="i:/ADNI/"
for dataset in os.listdir(dir):
filedir=dir+'dataset'
record_path="i:/AD_NC_%s.tfrecords"%dataset
create_TFrecord(filedir,record_path)
用几个数据测试,运行代码得到如下的文件
代码解析:
image_list=np.hstack((NC,AD)) 将两个列表连接成一个,如NC=[1,2,3],AD=[4,5,6],则 image_list=[1,2,3,4,5,6]
img= nib.load(image_path).get_fdata()是加载后缀为.nii的图像的函数,可以换成加载自己图像后缀的函数。加载为矩阵
二:读取batch数据
制作好自己的数据TFrecord文件后,尝试读取。在神经网络训练中一般都是一个batch送入网络训练,所以get_batch函数负责从TFrecord文件中读取数据并组合成一个batch,一次返回一组数据以及标签。
import tensorflow as tf
def get_batch(file_dir,batch_size): #默认一次100个数据
files=tf.train.match_filenames_once(file_dir)
filename_queue=tf.train.string_input_producer(files) #不随机打乱
#解析TFRecord文件里的数据
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
features=tf.parse_single_example(serialized_example,
features={
'img_raw': tf.FixedLenFeature([],tf.string),
'label': tf.FixedLenFeature([],tf.int64)
})
#得到图像数据、标签。
image,label=features['img_raw'],features['label']
#从原始图像数据解析出像素矩阵,并根据图像尺寸还原图像
#原数据为float64,解码也要float64
decode_image=tf.decode_raw(image,tf.float64)
decode_image=tf.reshape(decode_image,[140, 180, 150]) #我的数据尺寸
decode_image = tf.cast(decode_image, tf.float32)
decode_image=tf.expand_dims(decode_image,-1) #因为我是三维灰度数据,还需要增加一维(通道)
label = tf.cast(label, tf.int32)
label=tf.one_hot(label,2) #转化为one-hot编码
#将图像和标签数据整理成神经网络训练时需要的batch
# 抽取batch size个image、label
image_batch,label_batch = tf.train.batch([decode_image,label],
batch_size=batch_size,
capacity=3*batch_size,
)
#返回batch数据
return image_batch,label_batch
#测试是否能正确读出数据
def test(argv=None):
file_dir='i:/AD_NC_train.tfrecords-*'
image_batch,label_batch=get_batch(file_dir,1)
with tf.Session() as sess:
#变量初始化
tf.global_variables_initializer().run()
#由于train.match_filenames_once()返回的文件列表作为临时变量并没有保存到checkpoint,所以并不会作为全局变量被global_variables_initializer()函数初始化,
#所以要进行局部变量初始化,不然会报错
tf.local_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
try:
for i in range(5):
if coord.should_stop():
break
img, lab = sess.run([image_batch,label_batch])
#检查图像和标签形状
print(img.shape,lab.shape)
#更直观地检查就看能否正确显示图像
#for i in range(5):
# plt.imshow(img[i,:,:,30,0],cmap='gray') #选取切片查看,cmap='gray'显示灰度图
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
if __name__=='__main__':
tf.app.run(test)
运行代码测试,一次获得一个数据,循环5次
代码解析:
1.tf.train.match_filenames_once(file_dir)用于组合一个文件名列表,一般我们的数据可能不止一个TFrecord文件,比如file_dir=“data_train.tfrecords-*” :*
可以匹配所有类似的文件
2.tf.train.string_input_producer(files,shuffle=False):获得一个文件名队列,shuffle=False不打乱文件名列表
3.这段代码主要功能由get_batch提供,test()函数仅用于测试,在自己写代码时,应该测试保证读取无误后再进行下一步。
tf.app.run()用于启动主函数,这里将test作为主函数启动,参考这里
参考文章:https://blog.youkuaiyun.com/jyy555555/article/details/80283219