Tensorflow中数据集的一些高层操作的API函数及其用法示例

一.数据集中常用的高层API函数:

  1. dataset.map(func)
  2. dataset.shuffle(buffer_size)#随机打乱数据的顺序,buffer_size参数等价于tf.train.shuffle_batch中的min_after_dequeue参数。作用是保证每次打乱数据时的最低数量,要是打乱的数量过少,则打乱作用就不大了
  3. dataset.batch(batch_size)#将数据组合成batch
  4. dataset.repeat(N)#将数据集重复N份,repeat只代表重复相同的处理操作,并不会记录前一个epoch的处理结果。

二.具体使用如下:

import tensorflow as tf

train_files = tf.train.match_filenames_once('tfrecords/train_file-*')
test_file = tf.train.match_filenames_once('tfrecords/test_file-*')

def preprocess_for_train(image,height,width,bbox):
    pass

def build_net(input):pass

def calc_loss(logit,label):pass

def parse(record):
    features = tf.parse_single_example(
        record,
        features={
            'image':tf.FixedLenFeature([],tf.string),
            'label':tf.FixedLenFeature([],tf.int64),
            'height':tf.FixedLenFeature([],tf.int64),
            'width':tf.FixedLenFeature([],tf.int64),
            'channels':tf.FixedLenFeature([],tf.int64)
        }
    )

    image_data = tf.decode_raw(features['image'],tf.uint8)
    image_data.set_shape([features['height'],features['width'],features['channels']])
    label = features['label']

    return image_data,label

if __name__ == '__main__':
    image_size = 299
    batch_size = 100
    shuffle_buffer = 10000

    #读取数据集
    dataset = tf.data.TFRecordDataset(train_files)
    #将tfrecord转化成image,label的格式
    dataset = dataset.map(parse)
    #对image进行预处理
    dataset = dataset.map(
        lambda image,label:(preprocess_for_train(image,image_size,image_size,None),label)
    )
    #将数据集的顺序打乱,shuffle_buffer指定了队列中最少的元素个数
    dataset = dataset.shuffle(shuffle_buffer)
    #指定每次从迭代器中读出的数据个数,默认为1
    dataset = dataset.batch(batch_size)
    num_epoches = 10
    #将数据集中的数据重复num_epoches次,由于之前使用了shuffle,因此每个副本的顺序都不一定相同
    dataset = dataset.repeat(num_epoches)

    iterator = dataset.make_initializable_iterator()
    image_batch,label_batch = iterator.get_next()

    learning_rate = 0.01
    #构建网络,得到结果
    logit = build_net(image_batch)
    #结算损失
    loss = calc_loss(logit,label_batch)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    test_dataset = tf.data.TFRecordDataset(test_file)
    test_dataset = test_dataset.map(parse).map(
        lambda image,label:(tf.image.resize_images(image,[image_size,image_size]),label)
    )
    test_dataset = test_dataset.batch(batch_size)

    test_iterator = test_dataset.make_initializable_iterator()
    test_image_batch,test_label_batch = test_dataset.get_next()

    test_logit = build_net(test_image_batch)
    prediction = tf.argmax(test_logit,1)

    with tf.Session() as sess:
        #使用filename_match_once函数要初始化local_variables
        #使用迭代器要初始化iterator.initializer
        #global_variables一般都会初始化
        sess.run([tf.global_variables_initializer(),tf.local_variables_initializer(),iterator.initializer])

        while True:
            try:
                sess.run(train_step)
            except:
                break

        sess.run(test_iterator.initializer)
        test_results = []
        test_labels = []
        while True:
            try:
                pred,label = sess.run([prediction,test_label_batch])
                test_results.extend(pred)
                test_labels.extend(label)
            except:
                break

        correct = [float(y==y_) for (y,y_) in zip(test_results,test_labels)]
        acc = sum(correct)/len(correct)

        print(acc)

!!!上面例子中的循环体并不是指定循环运行次数,而是使用while(True) try-except的形式来将所有数据遍历一遍(即一个epoch)。这是因为在动态指定输入数据时,不同来源的数据数量大小难以预知,而这个方法我们不必提前知道数据量的大小。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值