tf.data 如何构建TensorFlow输入流水线

本文详细介绍了如何使用tf.data构建TensorFlow的输入流水线,包括从内存数据、Python生成器、TFRecord文件、CSV数据等不同来源创建Dataset对象,以及批处理、映射等转换操作。通过实例展示了数据预处理的关键步骤和方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.data: 如何构建TensorFlow输入流水线(version1)

步骤一:构建Dataset对象

构建Dataset对象有两种方式

  • 从存储在内存中的数据、或者一个或多个文件中的数据构造Dataset对象
  • 从一个或多个Dataset对象中构造Dataset对象

有了一个 Dataset 对象之后,您可以通过链接 tf.data.Dataset 对象上的方法调用将其转换成一个新的 Dataset。例如,您可以应用逐元素转换(例如 Dataset.map)和多元素转换(例如 Dataset.batch)。有关完整的转换列表,请参阅 tf.data.Dataset 文档。

Dataset 对象是一个 Python 可迭代对象,可以利用for循环使用它的元素

例子:从内存中的数据构造一个Dataset对象,您可以使用 tf.data.Dataset.from_tensors()tf.data.Dataset.from_tensor_slices()。或者,如果您的输入数据以推荐的 TFRecord 格式存储在文件中,则您可以使用 tf.data.TFRecordDataset()

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
for elem in dataset:
  print(elem.numpy())  # numpy变量
  print(elem)  # tensor变量

或者使用 iter 显式创建一个 Python 迭代器,并利用 next 来使用它的元素:

it = iter(dataset)

print(next(it)) 
print(next(it))
print(next(it))
print(next(it))
tf.Tensor(8, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(8, shape=(), dtype=int32)

通过存储在内存中的数据构造Dataset对象,numpy数据

如果所有的输入数据都适合装入内存,那么从这些数据创建 Dataset 的最简单方式是将它们转换为 tf.Tensor 对象并使用 Dataset.from_tensor_slices

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset

这种方式的优缺点:

这对于小数据集来说效果很好,但是会浪费内存(因为数组的内容会被多次复制),并且可能会达到 tf.GraphDef 协议缓冲区的 2GB 上限。

通过python生成器构造Dataset对象,图片生成器

另一个可被轻松整合为 tf.data.Dataset 的常用数据源是 Python 生成器。

小心:虽然这种方式比较简便,但它的可移植性和可扩缩性有限。它必须在创建生成器的同一 Python 进程中运行,且仍受 Python GIL 约束。

Dataset.from_generator 构造函数会将 Python 生成器转换为具有完整功能的 tf.data.Dataset

构造函数会获取可调用对象作为输入,而非迭代器。这样,构造函数结束后便可重启生成器。构造函数会获取一个可选的 args 参数,作为可调用对象的参数。

output_types 参数是必需的,因为 tf.data 会在内部构建 tf.Graph,而计算图边缘需要 tf.dtype

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())

# take(num) 表示拿num次迭代器的元素,如果num大于迭代器总共生成的元素数量num_2,只输出num_2个元素
# batch(num) 表示从迭代器中每次一次性取出num个数据
# .batch(10).take(2) 表示从迭代器中最多取两次数据,每次数据一次性取10个元素
# .repeat().batch(10).take(10)  表示从迭代器中取十次数据,每次数据一次性取10个元素,如果迭代器总共包含的元素不够10*10个元素,则会重复获取迭代器
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

Dataset.from_generator 构造函数会将 Python 生成器转换为具有完整功能的 tf.data.Dataset

构造函数会获取可调用对象作为输入,而非迭代器。这样,构造函数结束后便可重启生成器。构造函数会获取一个可选的 args 参数,作为可调用对象的参数。

output_types 参数是必需的,因为 tf.data 会在内部构建 tf.Graph,而计算图边缘需要 tf.dtype

通过CSV数据

csv文件格式存储着结构化表格数据

例如:

# 下载一个示例文件
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
# 看一下示例文件长什么样子
df = pd.read_csv(titanic_file)
df.head()

如果您的数据适合存储在内存中,那么 Dataset.from_tensor_slices 方法对字典同样有效,使这些数据可以被轻松导入:

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():  # 遍历字典feature_batch,feature_batch相当于一行
    print("  {!r:20s}: {}".format(key, value))
 'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

更具可扩展性的方式是根据需要从磁盘加载。

tf.data 模块提供了从一个或多个符合 RFC 4180 的 CSV 文件提取记录的方法。

experimental.make_csv_dataset 函数是用来读取 CSV 文件集的高级接口。它支持列类型推断和许多其他功能,如批处理和重排,以简化使用。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, # csv文件名
    batch_size=4,  # 获取批次,相当于之前一次获取一行,现在一次获取4行
    label_name="survived",  # csv文件中,标签列对应的列名
    select_columns=['class', 'fare', 'survived']  # 想要获取csv中的哪些列,这里只需获取其中三个列
)
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 0 1 0]
features:
  'sex'               : [b'female' b'male' b'female' b'male']
  'age'               : [45. 25. 41. 28.]
  'n_siblings_spouses': [1 0 0 0]
  'parch'             : [1 0 1 0]
  'fare'              : [164.8667   7.8958  19.5      7.8958]
  'class'             : [b'First' b'Third' b'Second' b'Third']
  'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
  'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Southampton']
  'alone'             : [b'n' b'y' b'n' b'y']

还有一个级别更低的 experimental.CsvDataset 类,该类可以提供粒度更细的控制,但它不支持列类型推断,您必须指定每个列的类型。

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])  # 结果是个列表,列表元素是tensor变量

[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

如果某些列为空,则此低级接口允许您提供默认值,而非列类型。

写入一个示例文件

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,

打印结果

record_defaults = [999,999,999]  # 每个列对应为空值时应该填充的值
dataset = tf.data.experimental.CsvDataset("missing.csv", 
                                          record_defaults,
                                          select_cols=[0,1,2] # 表示再【0,1,2,3】列中只是选择了【1,3】列
                                        )

for line in dataset.take(10):
  print([item.numpy() for item in line])
[1, 2, 3, 4]
[999, 2, 3, 4]
[1, 999, 3, 4]
[1, 2, 999, 4]
[1, 2, 3, 999]
[999, 999, 999, 999]

tf.data: 如何构建TensorFlow输入流水线(version2)

构建TensorFlow输入流水线的三个步骤

  • 步骤一:使用不同的数据源构建Dataset对象

    • 使用numpy数组
    • 使用python生成器
    • 使用TFRecord数据
    • 使用文本数据
    • 使用csv数据
    • 使用文件集
  • 步骤二:调用Dataset对象的方法将其转换成一个新的Dataset

    • 如批处理,Dataset.batch()
    • 如逐元素转换,Dataset.map()
    • 等等
  • 步骤三:Dataset对象是一个python可迭代对象,通过for循环使用它的元素

Dataset对象元素的结构

Dataset对象是一个python可迭代对象,它表示一个元素序列,其中每个元素都可由一个或多个组件组成。组件的类型可以是tf.Tensor、tf.TensorArray、tf.data.Dataset类型。

元素的结构可以是元组tuple、或者是字典dict。不可以是列表

如元组结构的元素:(tf.Tensor, tf.Tensor)或者(Dataset, Dataset);例如字典结构的元素:{‘sku_no’: tf.Tensor}

Dataset.element_spec 属性允许您检查每个元素组件的类型

演示例子,构建Dataset对象,并获取其元素

你可以从下面不同的演示例子中寻找到自己想要的方法

使用NumPy数组

如果训练的所有输入数据都可以装入内存中,那么根据这些数据构建Dataset的最简单方式就是将数据装换成tf.Tensor对象之后,并使用tf.data.Dataset.from_tensor_slices进行构建

# 生成模拟数据
img = np.random.randint(0, 255, size=(1000, 28, 28, 1))  # 生成1000张图片大小28*28尺寸的灰白图片
lable = np.random.randint(0, 9, size=(1000, 1))  # 生成这1000张图片对应的类别

# 将数据类型转成tf.Tensor
img = tf.constant(img,dtype=tf.float32)
lable = tf.constant(lable,dtype=tf.float32)

# 步骤一:构建Dataset对象
dataset = tf.data.Dataset.from_tensor_slices((img, lable))

# 查看Dataset对象中每个元素的结构及其组件类型
print(dataset.element_spec)

# 步骤二:通过Dataset对象的方法将其转换成一个新的Dataset对象。这里不做任何转换。后续会有详细解释
pass

# 步骤三:通过for循环获取Dataset对象中的元素
for img, label in dataset:
    print(img.shape, label.shape)
    print(type(img), type(label))
    break
(TensorSpec(shape=(28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(1,), dtype=tf.float32, name=None))
(28, 28, 1) (1,)
<class 'tensorflow.python.framework.ops.EagerTensor'> <class 'tensorflow.python.framework.ops.EagerTensor'>

from_tensor_slices(tensors)函数的作用是

创造一个元素是tensors的片段的Dataset对象。片段的形成通过切分tensors的第一个维度

使用python生成器(图片生成器)

通过Dataset.from_generator可以将python生成器转换成具有完整功能的Dataset对象

# 构建python生成器,生成器生成一张图片和该图片对应的label
def gen_img_lable():  
  i = 0
  while True:
    lable = np.random.randint(0, 9, size=(1, ))  # 生成图片对应的类别
    img = np.random.randint(0, 255, size=(28, 28, 1))  # 生成一张图片大小28*28通道数为1的灰白图片
    yield img, lable
    i += 1

# 步骤一:根据from_generator将Python生成器转换成Dataset对象
dataset = tf.data.Dataset.from_generator(
    generator=gen_img_lable,  # 传入生成器的可调用函数。gen_img_lable()称之为生成器,但是gen_img_lable称之为可调用函数
    output_shapes=((28, 28, 1),(1,)),  # 输出元素的形状.如果某个维度形状大小不确定,用None代替
    output_types=(tf.float32,tf.float32),  # 输出元素的类型
)

# 步骤二:应用Dataset对象的方法将Dataset对象转换成新的Dataset对象
dataset = dataset.batch(4)  # 将原先4个元素堆叠成一个元素,即批量输出

# 步骤三:通过for循环获取Dataset对象中的元素
for img, lable in dataset:
  print(img.shape, label.shape)
  break

使用TFRecord数据

使用文本数据

使用CSV数据

CSV可以存储结构化文本形式的数据

如果数据可以存储在内存中,可以选用Dataset.from_tensor_slices方法。通过将数据帧df转成字典,再输入from_tensor_slices即可

# 模拟两款商品的每天的销售件数
df = pd.DataFrame({
    'sku_no':['A001','A001','A001','A002','A002','A002',],
    'catetory':['shoes','shoes','shoes','T-shirt','T-shirt','T-shirt',],
    'date':['2022-10-01','2022-10-02','2022-10-03','2022-09-01','2022-09-02','2022-09-03',],
    'sellqty':[12,15,17,2,5,7]
})

if True: # 数据可以存储在内存中
    # 步骤一:生成dataset对象
    dataset = tf.data.Dataset.from_tensor_slices(dict(df))
    # 步骤二:应用合适的Dataset对象的方法转换成新的Dataset对象
    dataset = dataset.batch(2)  # 相当于一次性拿两行
    # 步骤三:使用for循环使用每个元素
    for element in dataset:
        print(element,type(element))
        for key, value in element.items():
            print(value)
        print()
{'sku_no': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'A001', b'A001'], dtype=object)>, 'catetory': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'shoes', b'shoes'], dtype=object)>, 'date': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'2022-10-01', b'2022-10-02'], dtype=object)>, 'sellqty': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 15])>} <class 'dict'>
tf.Tensor([b'A001' b'A001'], shape=(2,), dtype=string)
tf.Tensor([b'shoes' b'shoes'], shape=(2,), dtype=string)
tf.Tensor([b'2022-10-01' b'2022-10-02'], shape=(2,), dtype=string)
tf.Tensor([12 15], shape=(2,), dtype=int64)

{'sku_no': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'A001', b'A002'], dtype=object)>, 'catetory': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'shoes', b'T-shirt'], dtype=object)>, 'date': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'2022-10-03', b'2022-09-01'], dtype=object)>, 'sellqty': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([17,  2])>} <class 'dict'>
tf.Tensor([b'A001' b'A002'], shape=(2,), dtype=string)
tf.Tensor([b'shoes' b'T-shirt'], shape=(2,), dtype=string)
tf.Tensor([b'2022-10-03' b'2022-09-01'], shape=(2,), dtype=string)
tf.Tensor([17  2], shape=(2,), dtype=int64)

{'sku_no': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'A002', b'A002'], dtype=object)>, 'catetory': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'T-shirt', b'T-shirt'], dtype=object)>, 'date': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'2022-09-02', b'2022-09-03'], dtype=object)>, 'sellqty': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([5, 7])>} <class 'dict'>
tf.Tensor([b'A002' b'A002'], shape=(2,), dtype=string)
tf.Tensor([b'T-shirt' b'T-shirt'], shape=(2,), dtype=string)
tf.Tensor([b'2022-09-02' b'2022-09-03'], shape=(2,), dtype=string)
tf.Tensor([5 7], shape=(2,), dtype=int64)

如果想直接从csv文件中读取数据,可以使用tf.data.experimental.make_csv_dataset函数

# 模拟数据
df = pd.DataFrame({
    'sku_no':['A001','A001','A001','A002','A002','A002',],
    'catetory':['shoes','shoes','shoes','T-shirt','T-shirt','T-shirt',],
    'date':['2022-10-01','2022-10-02','2022-10-03','2022-09-01','2022-09-02','2022-09-03',],
    'sellqty':[12,15,17,2,5,7]
})
# 将数据保存成文件
df.to_csv('experimental_make_csv_dataset_file.csv')

# 步骤一:创建Dataset对象
dataset = tf.data.experimental.make_csv_dataset(
    './experimental_make_csv_dataset_file.csv',  # 指定文件
    batch_size=2,  # 指定批量大小
    # label_name="sellqty",  # 指定标签列
    select_columns=['sku_no', 'catetory']  # 选择哪些列作为特征
)

# 步骤二:应用Dataset对象的方法将其转换成新的Dataset对象
pass

# 步骤三:通过for循环获取Dataset对象的元素
for element in dataset:
    for key, value in element.items():
        print(key, value)
    break

使用文件集

演示例子,常用的Dataset对象的方法

简单批处理方法.batch()

最简单的批处理方式是将数据集的 n 个连续元素堆叠成单个元素。Dataset.batch() 转换就负责执行此操作,它有和 tf.stack() 算子相同的约束,应用于元素的每个组件:也就是说,对于每个组件 i,所有元素都必须有一个形状完全相同的张量。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)  # 如果要忽略最后一个批次,以获得完整的形状传播,用drop_remainder=True
# batched_dataset = dataset.batch(4,drop_remainder=True)  # 如果要忽略最后一个批次,以获得完整的形状传播,用drop_remainder=True
print(batched_dataset)
for batch in batched_dataset:
  print(batch)
  print()
  break
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>
(<tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 1, 2, 3])>, <tf.Tensor: shape=(4,), dtype=int64, numpy=array([ 0, -1, -2, -3])>)

如果最后一个批次大小不完整,形状会填充为None

如果要忽略最后一个批次,以获得完整的形状传播,用drop_remainder=True

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
# batched_dataset = dataset.batch(4)  # 如果要忽略最后一个批次,以获得完整的形状传播,用drop_remainder=True
batched_dataset = dataset.batch(4,drop_remainder=True)  # 如果要忽略最后一个批次,以获得完整的形状传播,用drop_remainder=True
print(batched_dataset)
for batch in batched_dataset:
  print(batch)
  print()
  break
<BatchDataset shapes: ((4,), (4,)), types: (tf.int64, tf.int64)>
(<tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 1, 2, 3])>, <tf.Tensor: shape=(4,), dtype=int64, numpy=array([ 0, -1, -2, -3])>)

epochs重复利用数据集的方法.repeat()

下面这种代码

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)

效果相当于

titanic_batches = dataset.batch(128,drop_remainder=True).repeat(3)

就是说假设整个数据集的元素个数是1200个,那么每次取128个元素,取了9次之后,剩下的8个元素形成不了一个批次,舍弃掉。然后重复这样的过程三次

随机重排输入数据.shuffle()

预处理数据.map()

Dataset.map(f) 转换会通过对输入数据集的每个元素应用一个给定函数 f 来生成一个新的数据集。它基于 map() 函数,该函数通常应用于函数式编程语言中的列表(和其他结构)。函数 f 会获取在输入中表示单个元素的 tf.Tensor 对象,并返回在新数据集中表示单个元素的 tf.Tensor 对象。它的实现使用标准的 TensorFlow 运算来将一个元素转换为另一个元素。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值