(二)TFRecord数据处理

本文介绍TensorFlow中数据分组(batch)与多线程处理的方法及效果对比,展示了如何通过数据分组和多线程提升模型训练效率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 数据分组

通过Dataset的batch对的对数据进行分组,即大量的数据集被分成多个小份数据,提高模型训练效率和利用多线程实现并行(分布式)计算,提高训练速度,接上一篇,CIFAR10选用了100张测试图片,分为4组,则每组有25个数据,即batch_size的尺寸.

import tensorflow as tf
import time
tf.reset_default_graph()
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")
'''组数据尺寸'''
BATCH_SIZE = 25

def parse(record):
	'''解析TFRecord数据.
	返回:
	features["image_raw"]:图像数据.
	features["image_num"]:图像数量.
	features["height"]:图像高度.
	features["width"]:图像宽度.
	'''
    features = tf.parse_single_example(
        record,
        features={"image_raw":tf.FixedLenFeature([], tf.string),
                  "image_num":tf.FixedLenFeature([], tf.int64),
                  "height":tf.FixedLenFeature([], tf.int64),
                  "width":tf.FixedLenFeature([], tf.int64),
                 }
    )
    image_raw = features["image_raw"]
    image_num = features["image_num"]
    height = features["height"]
    width = features["width"]
    return image_raw

'''TFRecord文件路径.'''
input_files = ["./outputs/cifar10.tfrecords"]
dataset = tf.data.TFRecordDataset(input_files)
'''数据映射解析.'''
dataset = dataset.map(parse)
'''数据分组'''
dataset = dataset.batch(BATCH_SIZE)
'''Iterator初始化,若没有变量,也可使用iterator = dataset.make_one_shot_iterator().'''
iterator = dataset.make_initializable_iterator()
'''遍历数据集'''
images = iterator.get_next()
'''images iterator: Tensor("IteratorGetNext:0", shape=(), dtype=string)'''
print("images iterator: {}".format(images))
# images = iterator.get_next()
def show_image(image,i):
    plt.imshow(image)
    plt.title("fig{}".format(i))

def iterator_infinite(sess, images):
    '''不清楚数据集大小情况下使用该遍历方法.'''
    image_num = 0
    while True:
        try:
            image = tf.decode_raw(images, tf.uint8)
            image_shape = sess.run(image)
            image_num += 1
            print("The {}th image, image column vector: {}".format(image_num, image_shape.shape))
        except tf.errors.OutOfRangeError:
            print("Total iamge number: {}".format(image_num))
            break
def iterator_loop(sess, images, nums):
    '''指定遍历次数(已获取数据量)
    参数:
    sess: 会话
    images: 数据张量.
    nums: 组数.
    总共100张图片, 每组25个,分成4组
    返回:
    数据集张量列表,长度为组数.
    '''
    image_batch = []
    for i in range(nums):
        image = tf.decode_raw(images, tf.uint8)
        image_batch.append(image)
        
    return image_batch
         
with tf.Session() as sess:
	start_time = time.time()
    sess.run(iterator.initializer)
    image_batch = iterator_loop(sess, images, 4)
    for order, image in enumerate(image_batch):
        image = sess.run(image)
        print("image column vector: {}".format(image.shape))
        plt.figure(figsize=(5, 5))
        plt.suptitle("第{}组数据".format(order+1), fontproperties=font, x=0.5, y=0.99, fontsize=15)
        for i in range(25):
        	'''每组25张图片''
            reshape_image = tf.reshape(image[i], [28, 28, 3])
            plt.subplot(5, 5, i+1).set_title("fig{}".format(i+1))
            plt.subplots_adjust(hspace=0.5)
            plt.axis("off") 
            plt.imshow(reshape_image.eval())
        plt.show()
	end_time = time.time()
	cost_time = end_time - start_time
	print("Time costed: {}".format(cost_time))

结果1:


在这里插入图片描述 在这里插入图片描述

图2.1 前两组

结果2:


在这里插入图片描述 在这里插入图片描述

图2.2 后两组

运行时间

Time costed: 22.302716493606567

2 多线程读取数据

import tensorflow as tf
import time
tf.reset_default_graph()
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")
'''组数据尺寸'''
BATCH_SIZE = 25

def parse(record):
	'''解析TFRecord数据.
	返回:
	features["image_raw"]:图像数据.
	features["image_num"]:图像数量.
	features["height"]:图像高度.
	features["width"]:图像宽度.
	'''
    features = tf.parse_single_example(
        record,
        features={"image_raw":tf.FixedLenFeature([], tf.string),
                  "image_num":tf.FixedLenFeature([], tf.int64),
                  "height":tf.FixedLenFeature([], tf.int64),
                  "width":tf.FixedLenFeature([], tf.int64),
                 }
    )
    image_raw = features["image_raw"]
    image_num = features["image_num"]
    height = features["height"]
    width = features["width"]
    return image_raw

'''TFRecord文件路径.'''
input_files = ["./outputs/cifar10.tfrecords"]
dataset = tf.data.TFRecordDataset(input_files)
'''数据映射解析.'''
dataset = dataset.map(parse)
'''数据分组'''
dataset = dataset.batch(BATCH_SIZE)
'''Iterator初始化,若没有变量,也可使用iterator = dataset.make_one_shot_iterator().'''
iterator = dataset.make_initializable_iterator()
'''遍历数据集'''
images = iterator.get_next()
'''images iterator: Tensor("IteratorGetNext:0", shape=(), dtype=string)'''
print("images iterator: {}".format(images))
# images = iterator.get_next()
def show_image(image,i):
    plt.imshow(image)
    plt.title("fig{}".format(i))

def iterator_infinite(sess, images):
    '''不清楚数据集大小情况下使用该遍历方法.'''
    image_num = 0
    while True:
        try:
            image = tf.decode_raw(images, tf.uint8)
            image_shape = sess.run(image)
            image_num += 1
            print("The {}th image, image column vector: {}".format(image_num, image_shape.shape))
        except tf.errors.OutOfRangeError:
            print("Total iamge number: {}".format(image_num))
            break
def iterator_loop(sess, images, nums):
    '''指定遍历次数(已获取数据量)
    参数:
    sess: 会话
    images: 数据张量.
    nums: 组数.
    总共100张图片, 每组25个,分成4组
    返回:
    数据集张量列表,长度为组数.
    '''
    image_batch = []
    for i in range(nums):
        image = tf.decode_raw(images, tf.uint8)
        image_batch.append(image)
        
    return image_batch
         
with tf.Session() as sess:
	start_time = time.time()
	'''开启协程'''
	coord = tf.train.Coordinator()
	'''开启线程'''
	threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(iterator.initializer)
    image_batch = iterator_loop(sess, images, 4)
    for order, image in enumerate(image_batch):
        image = sess.run(image)
        print("image column vector: {}".format(image.shape))
        plt.figure(figsize=(5, 5))
        plt.suptitle("第{}组数据".format(order+1), fontproperties=font, x=0.5, y=0.99, fontsize=15)
        for i in range(25):
        	'''每组25张图片'''
            reshape_image = tf.reshape(image[i], [28, 28, 3])
            plt.subplot(5, 5, i+1).set_title("fig{}".format(i+1))
            plt.subplots_adjust(hspace=0.5)
            plt.axis("off") 
            plt.imshow(reshape_image.eval())
        plt.show()
    '''终止线程'''
    coord.request_stop()
    '''线程lock:维护线程生命周期,当前线程任务执行结束,才开启下一个线程任务'''
    coord.join(threads)
	end_time = time.time()
	cost_time = end_time - start_time
	print("Time costed: {}".format(cost_time))

运行时间

Time costed: 17.85188627243042

3 对比:不分组不使用多线程

import tensorflow as tf
tf.reset_default_graph()
import matplotlib.pyplot as plt
import time
import cv2

BATCH_SIZE = 100
def parse(record):
    features = tf.parse_single_example(
        record,
        features={"image_raw":tf.FixedLenFeature([], tf.string),
                  "image_num":tf.FixedLenFeature([], tf.int64),
                  "height":tf.FixedLenFeature([], tf.int64),
                  "width":tf.FixedLenFeature([], tf.int64),
                 }
    )
    return features["image_raw"],features["image_num"], features["height"],features["width"]
input_files = ["./outputs/chinese_bai.tfrecords"]
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parse)
iterator = dataset.make_initializable_iterator()
images, num, height, width = iterator.get_next()
print("images iterator: {}".format(images))

def iterator_data_subplot(sess, num, images, height, width):
    plt.figure(figsize=(20, 20))
    for i in range(num):
#         '''Recovery data.'''
        image = tf.decode_raw(images, tf.uint8)
     
        height = tf.cast(height, tf.int32)
        width = tf.cast(width, tf.int32)
        image = tf.reshape(image, [height, width, 3])
        image = sess.run(image)
        plt.subplot(20,20,i+1).set_title("fig{}".format(i+1))
        plt.subplots_adjust(hspace=0.8)
        plt.axis("off")
        plt.imshow(image)
    plt.show()

def iterator_data_plot(sess, num, images, height, width):
    plt.figure()
    for i in range(num):
        image = tf.decode_raw(images, tf.uint8)
        height = tf.cast(height, tf.int32)
        width = tf.cast(width, tf.int32)
        image = tf.reshape(image, [height, width, 3])
        image = sess.run(image)
        plt.title("fig{}".format(i+1))
        plt.axis("off")
        plt.imshow(image)
        plt.show()
def without_batch():
    with tf.Session() as sess:
        start_time = time.time()
        sess.run(iterator.initializer)   
        print("image number: {}".format(num))
        iterator_data_subplot(sess, 400, images, height, width)
        end_time = time.time()
        cost_time = end_time - start_time
        print("Time costed: {}".format(cost_time))
if __name__ == "__main__":
    without_batch()

运行时间

Time costed: 34.26711893081665

4 总结

序号batch分组threads多线程运行时间/秒
1××34.26711893081665
2×22.302716493606567
317.85188627243042

(1) 对数据分组(batch)可提高模型训练效率,即把大量数据进行分组,每次训练读入组内数据,数据读取比不分组读取速度快;
(2) 多线程处理提高数据读取速度,可实现并行计算,同样分组,使用多线程读取比不开启线程读取速度快;
(3) 分组和多线程是训练过程的加速工具;
关于线程,协程,参考博客:
Tensorflow线程分析
Python之线程threading
Python之多进程multiprocessing


[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/train/Coordinator
[2]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/train/start_queue_runners


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天然玩家

坚持才能做到极致

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值