1 数据分组
通过Dataset的batch方法对数据进行分组,即大量的数据集被分成多个小份数据,提高模型训练效率,并利用多线程实现并行(分布式)计算,提高训练速度。接上一篇,CIFAR10选用了100张测试图片,分为4组,则每组有25个数据,即batch_size的尺寸为25。
import time

import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

tf.reset_default_graph()

# Font capable of rendering the CJK characters used in figure titles.
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")

# Number of examples per batch.
BATCH_SIZE = 25
def parse(record):
    """Parse one serialized TFRecord example.

    Args:
        record: scalar string tensor holding a serialized tf.train.Example.

    Returns:
        Scalar string tensor with the raw image bytes (features["image_raw"]).
    """
    features = tf.parse_single_example(
        record,
        features={
            "image_raw": tf.FixedLenFeature([], tf.string),
            "image_num": tf.FixedLenFeature([], tf.int64),
            "height": tf.FixedLenFeature([], tf.int64),
            "width": tf.FixedLenFeature([], tf.int64),
        })
    # Only the raw image bytes are used downstream; image_num, height and
    # width are parsed for schema completeness but intentionally not returned
    # (the original bound them to dead locals and its docstring wrongly
    # claimed all four were returned).
    return features["image_raw"]
# Path(s) of the TFRecord input file(s).
input_files = ["./outputs/cifar10.tfrecords"]

# Input pipeline: read records -> parse examples -> group into batches.
dataset = tf.data.TFRecordDataset(input_files).map(parse).batch(BATCH_SIZE)

# Initializable iterator; since no placeholders are involved,
# dataset.make_one_shot_iterator() would also work.
iterator = dataset.make_initializable_iterator()

# Tensor that yields the next batch each time it is evaluated.
images = iterator.get_next()
# e.g. images iterator: Tensor("IteratorGetNext:0", shape=(), dtype=string)
print("images iterator: {}".format(images))
def show_image(image, i):
    """Display a single image with the title "fig{i}"."""
    plt.imshow(image)
    plt.title("fig{}".format(i))
def iterator_infinite(sess, images):
    """Iterate the dataset until exhaustion (size unknown in advance).

    Args:
        sess: active tf.Session.
        images: string tensor produced by iterator.get_next().
    """
    # Build the decode op once, outside the loop: the original created a new
    # tf.decode_raw op on every iteration, growing the graph unboundedly.
    decoded = tf.decode_raw(images, tf.uint8)
    image_num = 0
    while True:
        try:
            image_shape = sess.run(decoded)  # each run advances the iterator
            image_num += 1
            print("The {}th image, image column vector: {}".format(image_num, image_shape.shape))
        except tf.errors.OutOfRangeError:
            # Fixed typo in the original message ("iamge" -> "image").
            print("Total image number: {}".format(image_num))
            break
def iterator_loop(sess, images, nums):
    """Return decode tensors for a known number of batches.

    With 100 images and BATCH_SIZE 25 the dataset splits into 4 batches.

    Args:
        sess: tf.Session (unused; kept for interface compatibility).
        images: string tensor produced by iterator.get_next().
        nums: number of batches to traverse.

    Returns:
        List of length `nums`; running each element fetches the next batch.
    """
    # One decode op suffices — every sess.run of it advances the iterator.
    # The original built `nums` identical ops, bloating the graph.
    decoded = tf.decode_raw(images, tf.uint8)
    return [decoded] * nums
with tf.Session() as sess:
    start_time = time.time()
    sess.run(iterator.initializer)
    image_batch = iterator_loop(sess, images, 4)
    for order, image in enumerate(image_batch):
        image = sess.run(image)
        print("image column vector: {}".format(image.shape))
        plt.figure(figsize=(5, 5))
        plt.suptitle("第{}组数据".format(order + 1), fontproperties=font, x=0.5, y=0.99, fontsize=15)
        for i in range(25):
            # 25 images per batch.  (The original had a broken triple quote
            # here — `''` — which was a syntax error.)
            # NOTE(review): CIFAR-10 images are 32x32x3; reshaping to
            # [28, 28, 3] only works if the records were written at 28x28 —
            # confirm against the TFRecord writer.
            # numpy reshape: same values as tf.reshape(...).eval(), but no
            # per-image graph ops and no extra session runs.
            reshape_image = image[i].reshape(28, 28, 3)
            plt.subplot(5, 5, i + 1).set_title("fig{}".format(i + 1))
            plt.subplots_adjust(hspace=0.5)
            plt.axis("off")
            plt.imshow(reshape_image)
        plt.show()
    end_time = time.time()
    cost_time = end_time - start_time
    print("Time costed: {}".format(cost_time))
结果1:


结果2:


运行时间
Time costed: 22.302716493606567
2 多线程读取数据
import time

import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

tf.reset_default_graph()

# CJK-capable font for the Chinese figure titles below.
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")

# Examples per batch.
BATCH_SIZE = 25
def parse(record):
    """Parse one serialized TFRecord example.

    Args:
        record: scalar string tensor holding a serialized tf.train.Example.

    Returns:
        Scalar string tensor with the raw image bytes (features["image_raw"]).
    """
    features = tf.parse_single_example(
        record,
        features={
            "image_raw": tf.FixedLenFeature([], tf.string),
            "image_num": tf.FixedLenFeature([], tf.int64),
            "height": tf.FixedLenFeature([], tf.int64),
            "width": tf.FixedLenFeature([], tf.int64),
        })
    # Only image_raw is consumed downstream; the remaining features are
    # parsed but not returned (the original assigned them to unused locals
    # and documented four return values that were never returned).
    return features["image_raw"]
# TFRecord input file path(s).
input_files = ["./outputs/cifar10.tfrecords"]

dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parse)      # parse each serialized example
dataset = dataset.batch(BATCH_SIZE)  # group into batches

# Initializable iterator (make_one_shot_iterator() also works when the
# pipeline contains no stateful placeholders).
iterator = dataset.make_initializable_iterator()
images = iterator.get_next()
# e.g. images iterator: Tensor("IteratorGetNext:0", shape=(), dtype=string)
print("images iterator: {}".format(images))
def show_image(image, i):
    """Render one image, titled "fig{i}"."""
    plt.imshow(image)
    plt.title("fig{}".format(i))
def iterator_infinite(sess, images):
    """Walk the dataset until OutOfRangeError (dataset size unknown).

    Args:
        sess: active tf.Session.
        images: string tensor produced by iterator.get_next().
    """
    # Hoisted out of the loop: the original added a fresh tf.decode_raw op
    # to the graph on every pass.
    decoded = tf.decode_raw(images, tf.uint8)
    image_num = 0
    while True:
        try:
            image_shape = sess.run(decoded)
            image_num += 1
            print("The {}th image, image column vector: {}".format(image_num, image_shape.shape))
        except tf.errors.OutOfRangeError:
            # "iamge" typo in the original output corrected.
            print("Total image number: {}".format(image_num))
            break
def iterator_loop(sess, images, nums):
    """Return decode tensors for `nums` batches (100 images / 25 = 4 here).

    Args:
        sess: tf.Session (unused; retained for interface compatibility).
        images: string tensor produced by iterator.get_next().
        nums: number of batches.

    Returns:
        List of length `nums`; each sess.run of an element pulls one batch.
    """
    # A single shared op is equivalent to the original's `nums` duplicate
    # ops and avoids graph growth.
    decoded = tf.decode_raw(images, tf.uint8)
    return [decoded] * nums
with tf.Session() as sess:
    start_time = time.time()
    # NOTE(review): Coordinator/start_queue_runners belong to the legacy
    # queue-based input pipeline; tf.data iterators register no queue
    # runners, so this very likely starts no threads here. Kept as-is for
    # parity with the article's timing comparison — verify before relying
    # on the "multithreaded" speedup.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(iterator.initializer)
    image_batch = iterator_loop(sess, images, 4)
    for order, image in enumerate(image_batch):
        image = sess.run(image)
        print("image column vector: {}".format(image.shape))
        plt.figure(figsize=(5, 5))
        plt.suptitle("第{}组数据".format(order + 1), fontproperties=font, x=0.5, y=0.99, fontsize=15)
        for i in range(25):
            # 25 images per batch.
            # NOTE(review): CIFAR-10 is 32x32x3; [28, 28, 3] assumes the
            # records were written at 28x28 — confirm with the writer.
            # numpy reshape replaces tf.reshape(...).eval(): identical
            # values, no per-image graph ops.
            reshape_image = image[i].reshape(28, 28, 3)
            plt.subplot(5, 5, i + 1).set_title("fig{}".format(i + 1))
            plt.subplots_adjust(hspace=0.5)
            plt.axis("off")
            plt.imshow(reshape_image)
        plt.show()
    # Request all worker threads to stop, then wait for them to finish.
    coord.request_stop()
    coord.join(threads)
    end_time = time.time()
    cost_time = end_time - start_time
    print("Time costed: {}".format(cost_time))
运行时间
Time costed: 17.85188627243042
3 对比:不分组不使用多线程
import time

import tensorflow as tf
import matplotlib.pyplot as plt
import cv2  # NOTE(review): unused in this script — candidate for removal

tf.reset_default_graph()

# Number of examples per batch (unused below: this variant never batches).
BATCH_SIZE = 100
def parse(record):
    """Parse one serialized TFRecord example.

    Args:
        record: scalar string tensor holding a serialized tf.train.Example.

    Returns:
        Tuple of tensors: (image_raw bytes, image_num, height, width).
    """
    feature_spec = {
        "image_raw": tf.FixedLenFeature([], tf.string),
        "image_num": tf.FixedLenFeature([], tf.int64),
        "height": tf.FixedLenFeature([], tf.int64),
        "width": tf.FixedLenFeature([], tf.int64),
    }
    features = tf.parse_single_example(record, features=feature_spec)
    return (features["image_raw"], features["image_num"],
            features["height"], features["width"])
# Input pipeline without batching: read -> parse, one record at a time.
input_files = ["./outputs/chinese_bai.tfrecords"]
dataset = tf.data.TFRecordDataset(input_files).map(parse)
iterator = dataset.make_initializable_iterator()
# One record per evaluation: raw bytes plus its metadata.
images, num, height, width = iterator.get_next()
print("images iterator: {}".format(images))
def iterator_data_subplot(sess, num, images, height, width):
    """Fetch `num` images one by one and draw them in a 20x20 subplot grid.

    Args:
        sess: active tf.Session.
        num: number of images to fetch (the grid holds at most 400).
        images: raw-bytes string tensor from iterator.get_next().
        height: int64 tensor, image height of the current record.
        width: int64 tensor, image width of the current record.
    """
    # Build decode/cast/reshape ops once. The original rebuilt all of them
    # inside the loop (one new subgraph per image) and also clobbered the
    # `height`/`width` parameters with tensors on each pass.
    decoded = tf.decode_raw(images, tf.uint8)
    h = tf.cast(height, tf.int32)
    w = tf.cast(width, tf.int32)
    image_op = tf.reshape(decoded, [h, w, 3])
    plt.figure(figsize=(20, 20))
    for i in range(num):
        image = sess.run(image_op)  # each run pulls the next record
        plt.subplot(20, 20, i + 1).set_title("fig{}".format(i + 1))
        plt.subplots_adjust(hspace=0.8)
        plt.axis("off")
        plt.imshow(image)
    plt.show()
def iterator_data_plot(sess, num, images, height, width):
    """Fetch `num` images one by one and draw each in its own figure.

    Args:
        sess: active tf.Session.
        num: number of images to fetch.
        images: raw-bytes string tensor from iterator.get_next().
        height: int64 tensor, image height of the current record.
        width: int64 tensor, image width of the current record.
    """
    # Ops hoisted out of the loop — the original added a fresh
    # decode/cast/reshape subgraph per image.
    decoded = tf.decode_raw(images, tf.uint8)
    h = tf.cast(height, tf.int32)
    w = tf.cast(width, tf.int32)
    image_op = tf.reshape(decoded, [h, w, 3])
    plt.figure()
    for i in range(num):
        image = sess.run(image_op)  # advances the iterator
        plt.title("fig{}".format(i + 1))
        plt.axis("off")
        plt.imshow(image)
    plt.show()
def without_batch():
    """Time the un-batched, single-threaded traversal of the dataset."""
    with tf.Session() as sess:
        start_time = time.time()
        sess.run(iterator.initializer)
        # NOTE(review): `num` is a tensor here, so this prints its repr,
        # not an integer count.
        print("image number: {}".format(num))
        iterator_data_subplot(sess, 400, images, height, width)
        end_time = time.time()
        cost_time = end_time - start_time
        print("Time costed: {}".format(cost_time))


if __name__ == "__main__":
    without_batch()
运行时间
Time costed: 34.26711893081665
4 总结
序号 | batch分组 | threads多线程 | 运行时间/秒 |
---|---|---|---|
1 | × | × | 34.26711893081665 |
2 | √ | × | 22.302716493606567 |
3 | √ | √ | 17.85188627243042 |
(1) 对数据分组(batch)可提高模型训练效率,即把大量数据进行分组,每次训练读入组内数据,数据读取比不分组读取速度快;
(2) 多线程处理提高数据读取速度,可实现并行计算,同样分组,使用多线程读取比不开启线程读取速度快;
(3) 分组和多线程是训练过程的加速工具;
关于线程,协程,参考博客:
Tensorflow线程分析
Python之线程threading
Python之多进程multiprocessing
[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/train/Coordinator
[2]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/train/start_queue_runners