tensorflow使用自己的数据训练网络(一)

本文介绍在学习TensorFlow时,训练自有三维图像数据(后缀为nii)的处理步骤。主要分为两步,一是制作TFrecord文件,需打乱数据后写入;二是读取batch数据,get_batch函数负责从TFrecord文件中读取并组合成batch,还给出了相关代码解析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在学习tensorflow时,需要训练自己的数据,则需要自己写代码读入数据、组装成batch。主要分为两个步骤

我的数据为三维的图像数据,后缀为nii。一般的二维图像数据也差不多是这个处理步骤

一:制作TFrecord文件。

假如目前在本地磁盘中有两类数据,已经按8:2的比例分为训练集和测试集,目录结构:
/train
     /AD
     /NC
/test
     /AD
     /NC
制作训练集或测试集时需要将两类数据打乱,否则文件中前几百个全是AD,后几百个全是NC。
第一步:得到AD和NC中的每个数据的绝对路径列表,设置标签NC为0,AD为1。连接列表后打乱列表(即打乱数据)。
第二步:加载列表中每个数据和对应标签,写入TFrecord文件。可以全部写入一个TFrecord文件,若数据量大也可以写入多个TFrecord文件,这里每300个数据写入一个TFrecord文件。

import os 
import tensorflow as tf 
import numpy as np
import nibabel as nib
#AD\NC二分类
#将训练集或测试集数据随机打乱放在一起
def create_TFrecord(file_dir,tfrecord_dir):
####################################################################
#############得到所有数据的路径和标签列表################################
    #存放数据的路径和标签列表
    AD=[]
    label_AD=[]
    NC=[]
    label_NC=[]

    for file in os.listdir(file_dir+'/NC'):
        NC.append(file_dir+'/NC'+'/'+file)
        label_NC.append(0)

    for file in os.listdir(file_dir+'/AD'):
        AD.append(file_dir+'/AD'+'/'+file)
        label_AD.append(1)        

    #将两个列表连接在一起
    image_list=np.hstack((NC,AD))
    label_list = np.hstack((label_NC,label_AD))
    #转化为数组并转置
    temp=np.array([image_list,label_list])
    temp=temp.transpose()
    #按行随机打乱
    np.random.shuffle(temp)
    images=list(temp[:,0])
    all_label=list(temp[:,1])
    labels = [int(float(i)) for i in all_label]
    
    #################################################################
    ##############每300个数据放入一个TFrecord文件######################
    file_num=0   #tfrecord文件计数
    data_num=300 #一个tfrecord文件中数据个数
    count=0    #数据计数
    recordname=record_path+"-%d"%(file_num)
    writer= tf.python_io.TFRecordWriter(recordname)# 创建一个writer
    for image_path,label in zip(images,labels):
        
        count=count+1
        print(count,data_num)
        if count>data_num:
            count=0
            file_num=file_num+1
            recordname=record_path+"-%d"%(file_num)
            writer= tf.python_io.TFRecordWriter(recordname)# 创建一个writer    
        #加载数据         
        img= nib.load(image_path).get_fdata()
        #print(img.shape)
        #转换为字节
        img_raw=img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
                                                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                                                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                                                }))
        writer.write(example.SerializeToString())
    writer.close()  # 关闭writer

if __name__=="__main__":
    dir="i:/ADNI/"
    for dataset in os.listdir(dir):
    	filedir=dir+'dataset'
    	record_path="i:/AD_NC_%s.tfrecords"%dataset
   		create_TFrecord(filedir,record_path)

用几个数据测试,运行代码得到如下的文件
在这里插入图片描述
代码解析:
image_list=np.hstack((NC,AD)) 将两个列表连接成一个,如NC=[1,2,3],AD=[4,5,6],则 image_list=[1,2,3,4,5,6]
img= nib.load(image_path).get_fdata()是加载后缀为.nii的图像的函数,可以换成加载自己图像后缀的函数。加载为矩阵

二:读取batch数据

制作好自己的数据TFrecord文件后,尝试读取。在神经网络训练中一般都是一个batch送入网络训练,所以get_batch函数负责从TFrecord文件中读取数据并组合成一个batch,一次返回一组数据以及标签。

import tensorflow as tf
 
def get_batch(file_dir,batch_size): #默认一次100个数据
    files=tf.train.match_filenames_once(file_dir)
    filename_queue=tf.train.string_input_producer(files) #不随机打乱
     
    #解析TFRecord文件里的数据
    reader=tf.TFRecordReader()
    _,serialized_example=reader.read(filename_queue)
    features=tf.parse_single_example(serialized_example,
    		features={
    			'img_raw': tf.FixedLenFeature([],tf.string),
    			'label': tf.FixedLenFeature([],tf.int64)
    			})
    		  
    #得到图像数据、标签。
    image,label=features['img_raw'],features['label']
     
    #从原始图像数据解析出像素矩阵,并根据图像尺寸还原图像
    #原数据为float64,解码也要float64
    decode_image=tf.decode_raw(image,tf.float64)
    decode_image=tf.reshape(decode_image,[140, 180, 150]) #我的数据尺寸
    decode_image = tf.cast(decode_image, tf.float32)
    decode_image=tf.expand_dims(decode_image,-1)  #因为我是三维灰度数据,还需要增加一维(通道)
        
    label = tf.cast(label, tf.int32)
    label=tf.one_hot(label,2)     #转化为one-hot编码
            
    	#将图像和标签数据整理成神经网络训练时需要的batch	
        # 抽取batch size个image、label
    image_batch,label_batch = tf.train.batch([decode_image,label],         
                                                         batch_size=batch_size,
                                                         capacity=3*batch_size,
                                                         )
    #返回batch数据
    return image_batch,label_batch
    
#测试是否能正确读出数据
def test(argv=None):
    file_dir='i:/AD_NC_train.tfrecords-*'
        
    image_batch,label_batch=get_batch(file_dir,1)
    with tf.Session() as sess:
        #变量初始化
        tf.global_variables_initializer().run()
        #由于train.match_filenames_once()返回的文件列表作为临时变量并没有保存到checkpoint,所以并不会作为全局变量被global_variables_initializer()函数初始化,
        #所以要进行局部变量初始化,不然会报错
        tf.local_variables_initializer().run()
        		
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        try:
            for i in range(5):
                if coord.should_stop():
                    break
                img, lab = sess.run([image_batch,label_batch])
                        #检查图像和标签形状
                print(img.shape,lab.shape)
                   
                    #更直观地检查就看能否正确显示图像
                    #for i in range(5):
                    #   plt.imshow(img[i,:,:,30,0],cmap='gray') #选取切片查看,cmap='gray'显示灰度图                        
        except tf.errors.OutOfRangeError:
            print('done!')
        finally:
            coord.request_stop()
        coord.join(threads)
if __name__=='__main__':
    tf.app.run(test)

运行代码测试,一次获得一个数据,循环5次

在这里插入图片描述
代码解析:
1.tf.train.match_filenames_once(file_dir)用于组合一个文件名列表,一般我们的数据可能不止一个TFrecord文件,比如file_dir=“data_train.tfrecords-*” :*可以匹配所有类似的文件

2.tf.train.string_input_producer(files,shuffle=False):获得一个文件名队列,shuffle=False不打乱文件名列表

3.这段代码主要功能由get_batch提供,test()函数仅用于测试,在自己写代码时,应该测试保证读取无误后再进行下一步。
tf.app.run()用于启动主函数,这里将test作为主函数启动,参考这里

参考文章:https://blog.youkuaiyun.com/jyy555555/article/details/80283219

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值