学习TF2.0时的一些总结,参考资官方文档https://www.tensorflow.org/guide/data#top_of_page
首先加载相关库
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
tf.data.Dataset
tf.data中提供了一个tf.data.Dataset()抽象,它可以表示一系列元素(图像和对应的标签)。Dataset的创建必须从数据开始.
dataset = tf.data.Dataset.from_tensor_slices([9, 3, 1, 5, 3, 7, 5, 0])
for elem in dataset: # 使用for循环
print(elem.numpy())
it = iter(dataset) # 使用iterator对象
print(next(it).numpy())
Dataset对象的element_spec属性可以查看每个元素的类型,返回一个tf.TypeSpec对象
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
print(dataset1.element_spec)
Dataset对象支持任何结构,可以使用Dataset.map()和Dataset.filter()对每个元素进行操作
for z in dataset1:
print(z.numpy())
数据生成器
使用python的生成器作为数据源,Dataset.from_generator()将python生成器转化为tf.data.Dataset对象
def count(stop):
i = 0
while i < stop:
yield i
i += 1
for n in count(5):
print(n)
from_generator()有三个参数需要注意:args是需要传给函数的参数,output_types是创建tf.Graph时需要的参数,output_shapes是返回的数据大小
ds_counter = tf.data.Dataset.from_generator(
count, args=[25], output_types=tf.int32)
for count_batch in ds_counter.repeat().batch(7).take(5): # 每个batch取7个,总共取5个batch
print(count_batch.numpy())
一般情况下,最好将output_types和output_shape明确指定
def gen_series():
i = 0
while True:
size = np.random.randint(0, 10)
yield i, np.random.normal(size=(size,))
i += 1
for i, series in gen_series():
print(i, ":", series)
if i > 5:
break
ds_series = tf.data.Dataset.from_generator(
gen_series,
output_types=(tf.int32, tf.float32),
output_shapes=((), (None,))) # 明确指定outpu_types和output_shapes
print(ds_series)
当对一个变长的数据进行batch时,可以使用Dataset.padded_batch
ds_series_batch = ds_series.shuffle(
20).padded_batch(10, padded_shapes=([], [None]))
ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
加载图像数据
在处理图像数据时,将preprocessing.image.ImageDataGenerator封装为一个tf.data.Dataset
flowers = tf.keras.utils.get_file(
'flower_photos',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
untar=True
)
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers)) # 使用iteration方式遍历
ds = tf.data.Dataset.from_generator(
img_gen.flow_from_directory,
args=[flowers],
output_types=(tf.float32, tf.float32)
output_shapes=([32, 256, 256, 3], [32, 5])
)
加载CSV数据
也可以将CSV类型的数据直接通过from_tensor_slices()生成tf.data.Dataset
titanic_file = tf.keras.utils.get_file(
"train.csv",
"https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df)) # 通过字典的形式加载数据
for feature_batch in titanic_slices.take(1):
for key, value in feature_batch.items():
print(" {!r:20s}: {}".format(key, value))
tf.data有一种更灵活的数据加载方式,experimental.make_csv_dataset()支持列级别的查询,batching和shuffling。
titanic_batchs = tf.data.experimental.make_csv_dataset(
titanic_file,
batch_size=4,
label_name='survived'
)
for feature_batch, label_batch in titanic_batchs.take(1):
print("'survived' : {}".format(label_batch))
print("features:")
for key, value in feature_batch.items():
print(" {!r:20s}: {}".format(key, value))
同时可以使用select_columns参数选择特征的子集
titanic_batchs = tf.data.experimental.make_csv_dataset(
titanic_file,
batch_size=4,
label_name='survived',
# select_columns里一定要包含label_name
select_columns=['class', 'fare', 'survived']
)
for feature_batch, label_batch in titanic_batchs.take(1):
print("'survived: {}".format(label_batch))
for key, value in feature_batch.items():
print(" {!r:20s}: {}".format(key, value))
还有一个底层的类experimental.CsvDataset提供更精细的控制,这种方法不支持列的查询,同时必须指定每一列的数据类型
titanic_types = [tf.int32, tf.string, tf.float32, tf.int32,
tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(
titanic_file, titanic_types, header=True)
for line in dataset.take(10):
print([item.numpy() for item in line])
Batching
Dataset.batch()可以将n个连续的元素堆叠到一个元素中,tf.stack()的功能与该函数相同
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4) # 每个batch取4个case
for batch in batched_dataset.take(5): # 取5个batch
print([arr.numpy() for arr in batch])
Dataset.batch()可能会造成不明确的batch大小,因为最后一个batch可能不足,drop_remainder可以忽略掉最后一个batch
batched_dataset = dataset.batch(7, drop_remainder=True)
print(batched_dataset)
针对大多数输入大小可变的模型,Dataset.padded_batch()可以对不同维度的大小进行改变
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x)) # tf.fill()将x扩展到第一个参数的维度,x必须是标量
dataset = dataset.padded_batch(4, padded_shapes=(None,), drop_remainder=True)
for batch in dataset.repeat().take(5):
print(batch.numpy())
print()
本文介绍了TensorFlow 2.0中数据加载的方法,包括使用tf.data.Dataset从各种来源加载数据,如张量切片、生成器、图像、CSV文件等,并详细讲解了Dataset对象的属性和操作,如map、filter、batch等。
3252

被折叠的 条评论
为什么被折叠?



