Tensorflow 数据对象Dataset.shuffle()、repeat()、batch() 等用法

本文详细介绍了TensorFlow中的Dataset类及其常用方法,包括from_tensor_slices、from_tensors和from_generator,以及数据转换操作如batch、shuffle和map。通过实例展示了如何创建和操作数据集,以及如何对数据进行批量处理、随机打乱和应用函数映射。同时强调了数据处理顺序的重要性,如先shuffle后batch以确保数据的随机性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.Dataset数据对象

Dataset可以用来表示输入管道元素集合(张量的嵌套结构)和“逻辑计划“对这些元素的转换操作。在Dataset中元素可以是向量,元组或字典等形式。
另外,Dataset需要配合另外一个类Iterator进行使用,Iterator对象是一个迭代器,可以对Dataset中的元素进行迭代提取。

2.Dataset方法

2.1 产生数据集
2.1.1. from_tensor_slices

from_tensor_slices 用于创建dataset,其元素是给定张量的切片的元素。

函数形式:from_tensor_slices(tensors)

参数tensors:张量的嵌套结构,每个都在第0维中具有相同的大小。

import tensorflow as tf
#创建一个Dataset对象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9,10,11,12])
'''合成批次'''
dataset=dataset.batch(5)
#创建一个迭代器
iterator = dataset.make_one_shot_iterator()
#get_next()函数可以帮助我们从迭代器中获取元素
element = iterator.get_next()
 
#遍历迭代器,获取所有元素
with tf.Session() as sess:   
    for i in range(9):
       print(sess.run(element))  

输出

[1 2 3 4 5]
[ 6  7  8  9 10]
[11 12]

2.1.2 .from_tensors

创建一个Dataset包含给定张量的单个元素。

函数形式:from_tensors(tensors)

参数tensors:张量的嵌套结构。

dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])
 
iterator = concat_dataset.make_one_shot_iterator()
 
element = iterator.get_next()
 
with tf.Session() as sess:   
    for i in range(1):
       print(sess.run(element))

区别:

  • from_tensors是将tensors作为一个整体进行操纵,而from_tensor_slices可以操纵tensors里面的元素。

2.1.3 from_generator(具体实践不太了解)

创建Dataset由其生成元素的元素generator。

函数形式:from_generator(generator,output_types,output_shapes=None,args=None)

参数generator:一个可调用对象,它返回支持该iter()协议的对象 。如果args未指定,generator则不得参数; 否则它必须采取与有值一样多的参数args。
参数output_types:tf.DType对应于由元素生成的元素的每个组件的对象的嵌套结构generator。
参数output_shapes:tf.TensorShape 对应于由元素生成的元素的每个组件的对象 的嵌套结构generator
参数args:tf.Tensor将被计算并将generator作为NumPy数组参数传递的对象元组。

 

2.2 数据转换Transformation
2.2.1 batch

# 创建0-10的数据集,每个batch取个数6。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

但是如果我们把循环次数设置成3(即for i in range(2)),那么就会报错。

或者将for循环改为:while True:。就不用设置循环次数了。

 

2.2.2 shuffle


上面所有输出结果都是有序的,在机器学习中训练模型需要将数据打乱,这样可以保证每批次训练的时候所用到的数据集是不一样的,可以提高模型训练效果。

注意:shuffle的顺序很重要,应该先shuffle再batch,如果先batch后shuffle的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱(实际并未shuffle)。

随机混洗数据集的元素。

函数形式:shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

参数buffer_size:表示新数据集将从中采样的数据集中的元素数。

buffer_size=1:不打乱顺序,既保持原序
buffer_size越大,打乱程度越大
参数seed:(可选)表示将用于创建分布的随机种子。
参数reshuffle_each_iteration:(可选)一个布尔值,如果为true,则表示每次迭代时都应对数据集进行伪随机重组。(默认为True。)

在这里buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从dataset中提取数据将它填满。batch操作,从buffer中提取。

如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size= Dataset_size。

 

2.2.3 map


map可以将map_func函数映射到数据集.
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,

函数形式:flat_map(map_func,num_parallel_calls=None)

参数map_func:映射函数
参数num_parallel_calls:表示要并行处理的数字元素。如果未指定,将按顺序处理元素。如果使用值tf.data.experimental.AUTOTUNE,则根据可用的CPU动态设置并行调用的数量。

对dataset中每个元素的值加10

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

[16 17 18 19]
[10 11 12 13 14 15]


2.2.4 repeat

重复此数据集次数,主要用来处理机器学习中的epoch,假设原先的数据训练一个epoch,使用repeat(2)就可以将之变成2个epoch,默认空是无限次。

函数形式:repeat(count=None)

参数count:(可选)表示数据集应重复的次数。默认行为(如果count是None或-1)是无限期重复的数据集。

 


————————————————
版权声明:本文为优快云博主「rrr2」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.youkuaiyun.com/qq_35608277/article/details/116333888

在使用TensorFlow进行深度学习模型训练时,有效地应用dataset.shuffledataset.batch以及dataset.repeat方法,可以极大地提升训练过程的效率和效果。为了深入理解这些方法的用法及其背后的原理,建议阅读《TensorFlowdataset.shuffledataset.batchrepeat用法解析》这篇文章。 参考资源链接:[TensorFlowdataset.shuffledataset.batchrepeat用法解析](https://wenku.csdn.net/doc/64534c7cea0840391e779466) 首先,数据的批处理(batching)是提高内存利用率和训练速度的关键步骤。通过`dataset.batch(batch_size)`,你可以将数据集分批处理,每个批次作为一个训练步骤的输入。批次大小的选择对模型的收敛速度和稳定性有很大影响。较小的批次大小可以提高模型的泛化能力,但较大的批次大小可以更有效地利用硬件加速。 其次,数据洗牌(shuffling)是确保模型不会过拟合的重要步骤,因为它可以防止模型学习到数据集中的任何特定顺序。使用`dataset.shuffle(buffer_size)`方法,可以在每个epoch开始前打乱数据,其中`buffer_size`的大小决定了内存中用于随机抽取样本的缓冲区大小。如果缓冲区较小,可能会导致数据洗牌不充分,从而影响模型的训练效果。 接着,重复数据集(repeating)是通过多次遍历数据集来增加训练周期,这对于小数据集特别重要。通过`dataset.repeat(num_epochs)`,可以指定数据集被重复的次数,模拟长时间的训练过程。注意,`repeat`应该在`batch`之前调用,以便每个epoch中的每个批次都是从随机化后的数据集开始。 通过合理设置这些参数,可以构建一个高效的数据输入管道,提升模型训练的效率。下面是一个简化的代码示例,展示了如何结合这些方法: ```python import tensorflow as tf # 假设我们有一些数据和标签 data = tf.constant([[1., 2.], [3., 4.]]) # 示例数据 labels = tf.constant([0., 1.]) # 示例标签 # 创建一个TensorFlow数据集 dataset = tf.data.Dataset.from_tensor_slices((data, labels)) # 设置缓冲区大小为2,对数据进行打乱 dataset = dataset.shuffle(buffer_size=2) # 将数据分为批次,每个批次包含1个样本 dataset = dataset.batch(batch_size=1) # 重复数据集2次 dataset = dataset.repeat(count=2) # 创建迭代器以访问数据 iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() # 使用会话来运行迭代器并获取数据 with tf.Session() as sess: for _ in range(8): # 应该会打印出4个批次,每个批次1个样本 print(sess.run(next_element)) ``` 在上述代码中,我们创建了一个简单的数据集,并按照`shuffle`、`batch`和`repeat`的顺序对数据集进行处理。通过这种方式,你可以更好地控制TensorFlow训练过程中数据的处理方式。 为了深入学习并理解如何在实际项目中应用这些技术,以及如何根据具体情况调整参数,强烈推荐阅读《TensorFlowdataset.shuffledataset.batchrepeat用法解析》这篇文章。它不仅详细解释了每个方法的作用和原理,还提供了实际应用的案例,帮助你更好地利用这些工具来优化你的深度学习模型训练流程。 参考资源链接:[TensorFlowdataset.shuffledataset.batchrepeat用法解析](https://wenku.csdn.net/doc/64534c7cea0840391e779466)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值