TensorFlow学习(一)

本文通过一个具体案例介绍如何使用TensorFlow的TFRecord格式保存和读取数据。包括创建训练、验证和测试数据集,并通过TFRecordReader进行数据读取。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用TensorFlow的TFRecord来保存和读取数据

在网上看别人写的程序,总觉得已经明白了,没想到自己来写,发现有好多的坑:

  • TFRecordWrite/Read不支持中文路径(至少我这里)
  • 一定要注意tf.train.Example的层次结构
    • Features
      • feature1
        • {key=string, value=tf.train.Feature()}
        • {key, value}
      • feature2
  • 我使用parse_single_example来解析TFRecordReader后的记录
    • features一定要与TFRecordWriter对应
  • 在使用string_input_producer来读记录的时候,如果设置了num_epochs的值,一定还要使用local_variables_initializer初始化一次(仅仅用global_variables_initializer会报错)

代码如下

本代码目的,CSV文件中包含JPG文件名和类型,需要按照一定比例分别创建Train, Validation和Test数据集,集合文件为TFRecordFormat,然后用TFRecordReader方式读出来:

“`python

import io
import sys
import os
import csv
import random
from PIL import Imge

#filepath = 'D:/文件/数据资料/图像数据/照片/正面分割缩小'
#TFRecord不支持中文路径
filepath = 'D:\\temp'

csvfilename='BL2-CO3.csv'
#三种集合的文件名
trainfile = 'train.tfrecord'
valfile = 'validation.tfrecord'
testfile = 'test.tfrecord'

imgwidth = 1024
imgheight = 384

#创建数据集合
def CreateData():
    #设置各集合的比例
    trainratio = 0.7
    valratio = 0.15
    testratio = 0.15

    #读取CSV文件,第一列是文件名,第二列是类型
    fileList= []
    with open(os.path.join(filepath, csvfilename), newline='') as csvfile:
        spamreader = csv.reader(csvfile, delimiter=',')
        isheader = True
        for row in spamreader:
            if isheader==False:
                fileList.append([row[0], int(row[1])])
            isheader = False;

    #打乱顺序
    filecount = len(fileList)
    random.shuffle(fileList)

    #创建训练、验证和测试集合
    index1 = round(filecount * trainratio)
    trainlist = fileList[0:index1]

    index2 = round(filecount*valratio) + index1;
    vallist = fileList[index1:index2]
    testlist = fileList[index2:]        

    #保存列表到tfrecord文件
    def SaveDataSet(datalist, savepath):
        writer = tf.python_io.TFRecordWriter(savepath)
        for item in datalist:
           img = Image.open(item[0])
           img.load()

           #定义label和image两个属性
           #使用flatten()将多维图像数据扁平化[height, width, channel] -->[byte]
           #使用tostring()转换为byte array
           record = tf.train.Example(
               features = tf.train.Features(
                    feature = {
                        'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[item[1]])),
                        'image':tf.train.Feature(bytes_list= tf.train.BytesList(value=[np.asarray(img, dtype=np.uint8).flatten().tostring()]))
                        }
                    )
                )
            writer.write(record.SerializeToString())

        writer.close()

    SaveDataSet(trainlist, os.path.join(filepath, trainfile))
    SaveDataSet(vallist, os.path.join(filepath, valfile))
    SaveDataSet(testlist, os.path.join(filepath, testfile))

#读取数据集
def TestData(filename, epochs):
    #读取一条记录 (在线程里面调用)
    def ReadOneRecord(filename, epochs):
        #创建文件名队列,这里只有一个文件
        filename_queue = tf.train.string_input_producer([filename], num_epochs=epochs)  
        #开始读取记录
        reader = tf.TFRecordReader()
        _,record = reader.read(filename_queue)
    #解析记录
        features = tf.parse_single_example(
            record,
            features={
                'label':tf.FixedLenFeature([], tf.int64),
                'image':tf.FixedLenFeature([],tf.string)
                }
            )

        #解析出来的内容
        label = features['label']
        image = tf.decode_raw(features['image'], tf.uint8)
        image = tf.reshape(image, [imgheight, imgwidth, 3])

        return label, image

    #获取所有集合数据
    def Feech_Data(sess, coord, threads, label_batch, image_batch):
        try:
            while not coord.should_stop():
                #如果没有取完数据(num_epochs没有结束)
                label, image = sess.run([label_batch, imag  e_batch])
                #取数据
                print('labelsize=%d, imagesize=%d'%(label.size, image.size))
        except tf.errors.OutOfRangeError:  #表示num_epoch结束了
            print('done training')
        finally:
            coord.request_stop()

        #等待所有线程结束
        coord.join(threads)

    ###获取数据##

    #引用记录解析过程
    label, image = ReadOneRecord(filename, epochs)
    #TF批次取数据,capacity和min_after_dequeue参数设置比较重要
    #网上有不少讲如何设置的,我只是实验,没有讲究
    label_batch, image_batch = tf.train.shuffle_batch([label, image], batch_size=5, capacity=30, min_after_dequeue=10, num_threads=2)

    #获取shuffle_batch创建的线程
    #queues = tf.get_collection(tf.GraphKeys.QUEUE_RUS)

    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    #一定要运行一下这个初始化,否则num_epochs报错
    init = tf.local_variables_initializer()
    sess.run(init)

    #控制数据获取线程是否结束       
    coord = tf.train.Coordinator()

    #启动数据获取线程
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    Feech_Data(sess=sess, coord=coord, threads=threads, label_batch=label_batch, image_batch=image_batch)

    sess.close()

###### 1=创建数据, 0=获取数据######
if 0 :
    CreateData()
else:
    TestData(os.path.join(filepath, trainfile), 50)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值