TensorFlow Estimator 官方文档之----Dataset for Estimator


tf.data 模块包含一系列类,让您可以轻松地加载数据、操作数据并通过管道将数据传送到模型中。本文档通过两个简单的示例来介绍该 API:

  • 从 Numpy 数组中读取内存中的数据。
  • 从 csv 文件中读取行。

从 Numpy 数组中读取内存中的数据

要开始使用 tf.data,最简单的方法是从内存中的数组中提取切片。

内置 Estimator 一章介绍了 iris_data.py 中的以下 train_input_fn,它可以通过管道将数据传输到 Estimator 中:

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

我们来详细了解一下。

参数

此函数需要三个参数。要求所赋值为“数组”的参数能够接受可通过 numpy.array 转换成数组的几乎任何值。其中存在一个例外,即对 Datasets 有特殊意义的 tuple,稍后我们会发现这一点。

  • features:包含原始输入特征的 {‘feature_name’:array} 字典(或 DataFrame)。
  • labels:包含每个样本的标签的数组。
  • batch_size:表示所需批次大小的整数。

premade_estimator.py 中,我们使用 iris_data.load_data() 函数检索了鸢尾花数据。您可以运行该函数并解压结果,如下所示:

import iris_data

# Fetch the data
train, test = iris_data.load_data()
features, labels = train

然后,我们使用类似以下内容的行将此数据传递给了输入函数:

batch_size=100
iris_data.train_input_fn(features, labels, batch_size)

下面我们详细介绍一下 train_input_fn()

切片

首先,此函数会利用 tf.data.Dataset.from_tensor_slices 函数创建一个代表数组切片的 tf.data.Dataset。系统会在第一个维度内对该数组进行切片。例如,mnist 训练数据的数组的形状为 (60000, 28, 28)。将该数组传递给 from_tensor_slices 会返回一个包含 60000 个切片的 Dataset 对象,其中每个切片都是一个 28x28 的图像。

返回此 Dataset 的代码如下所示:

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

这段代码将输出以下行,显示数据集中条目的 shapes 和 dtypes。请注意Dataset 不知道自己包含多少条目。

<TensorSliceDataset shapes: (28,28), types: tf.uint8>

上面的 Dataset 表示一组简单的数组,但实际的数据集要比这复杂得多。Dataset 可以按照透明方式处理字典或元组(或 namedtuple)的任何嵌套组合。

例如,在将鸢尾花 features 转换为标准 Python 字典后,您可以将数组字典转换为字典 Dataset,如下所示:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

我们可以看到,如果 Dataset 包含结构化元素,则 Datasetshapestypes 将采用同一结构。此数据集包含所有类型为 tf.float64标量字典。

鸢尾花 train_input_fn 的第一行使用相同的功能,但添加了另一层结构。它会创建一个包含 (features_dict, label) 对的数据集。

以下代码显示标签是类型为 int64 的标量:

# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (), 
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64, 
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

操作

目前,Dataset 会按固定顺序迭代数据一次,并且一次仅生成一个元素。它需要进一步处理才可用于训练。幸运的是,tf.data.Dataset 类提供了更好地准备训练数据的方法。输入函数的下一行就利用了其中的几种方法:

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle 方法使用一个固定大小的缓冲区,在条目经过时随机化处理条目。在这种情况下,buffer_size 大于 Dataset 中样本的数量,确保数据完全被随机化处理(鸢尾花数据集仅包含 150 个样本)。

repeat 方法会在结束时重启 Dataset。要限制周期数量,请设置 count 参数。

batch 方法会收集大量样本并将它们堆叠起来以创建批次。这为批次的形状增加了一个维度。新的维度将添加为第一个维度。以下代码对之前的 MNIST Dataset 使用 batch 方法。这样会产生一个包含表示 (28,28) 图像堆叠的三维数组的 Dataset

print(mnist_ds.batch(100))
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

请注意,该数据集的批次大小是未知的,因为最后一个批次具有的元素数量会减少。

train_input_fn 中,经过批处理之后,Dataset 包含元素的一维向量,其中每个标量之前如下所示:

print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

返回

此时,Dataset 包含 (features_dict, labels) 对。这是 trainevaluate 方法的预期格式,因此 input_fn 会返回相应的数据集。

使用 predict 方法时,可以/应该忽略 labels

读取 CSV 文件

Dataset 类最常见的实际用例是流式传输磁盘上文件中的数据。tf.data 模块包含各种文件阅读器。我们来看看如何使用 Dataset 解析 csv 文件中的 Iris 数据集。

iris_data.maybe_download 函数的以下调用会根据需要下载数据,并返回所生成文件的路径名:

import iris_data
train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn 函数包含使用 Dataset 解析 csv 文件的备用实现。

我们来了解一下如何构建从本地文件读取数据且兼容 Estimator 的输入函数。

构建 Dataset

我们先构建一个 TextLineDataset 对象,实现一次读取文件中的一行数据。然后,我们调用 skip 方法来跳过文件的第一行,此行包含标题,而非样本:

ds = tf.data.TextLineDataset(train_path).skip(1)

构建 csv 行解析器

我们先构建一个解析单行的函数。

以下 iris_data.parse_line 函数会使用 tf.decode_csv 函数和一些简单的 Python 代码来完成此任务:

为了生成必要的 (features, label) 对,我们必须解析数据集中的每一行。以下 _parse_line 函数会调用 tf.decode_csv,以将单行解析为特征和标签两个部分。由于 Estimator 需要将特征表示为字典,因此我们依靠 Python 的内置 dictzip 函数来构建此字典。特征名称是该字典的键。然后,我们调用字典的 pop 方法以从特征字典中移除标签字段:

# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS,fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label

解析行

数据集提供很多用于在通过管道将数据传送到模型的过程中处理数据的方法。最常用的方法是 map,它会对 Dataset 的每个元素应用转换。

map 方法会接受 map_func 参数,此参数描述了应该如何转换 Dataset 中的每个条目。

在这里插入图片描述

因此,为了在从 csv 文件中流式传出行时对行进行解析,我们将 _parse_line 函数传递给 map 方法:

ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

现在,数据集包含 (features, label) 对,而不是简单的标量字符串。

iris_data.csv_input_fn 函数的剩余部分与 iris_data.train_input_fn 函数完全相同,后者在基本输入部分中进行了介绍。

试试看

此函数可用于替换 iris_data.train_input_fn。可使用此函数馈送 Estimator,如下所示:

train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[: -1]
]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
    n_classes = 3)# Train the estimator
batch_size = 100
est.train(
    steps = 1000,
    input_fn = lambda: iris_data.csv_input_fn(train_path, batch_size))

Estimator 要求 input_fn 不接受任何参数。为了不受此限制约束,我们使用 lambda 来获取参数并提供所需的接口。

总结

tf.data 模块提供一系列类和函数,可用于轻松从各种来源读取数据。此外,tf.data 还提供简单而又强大的方法,用于应用各种标准和自定义转换。

现在,您已经基本了解了如何高效地将数据加载到 Estimator 中。接下来,请查看下列文档:

<think>好的,用户想了解TensorFlow Estimator在版本1.13.0中的使用方法和文档。首先,我需要确认用户的需求。他们可能在使用较旧的TensorFlow版本,需要特定于1.13.0的信息,可能是因为项目依赖或兼容性问题。 首先,我应该回忆TensorFlow Estimator的基本概念。EstimatorTensorFlow的高级API,用于简化模型训练、评估和预测。在1.x版本中,Estimator是推荐的方式,尤其是在生产环境中。版本1.13.0属于较早期的版本,发布于2019年左右,可能与当前文档有所不同。 接下来,用户需要具体的用法和文档。由于TensorFlow的文档会随着版本更新而变化,直接访问当前文档可能不适用于1.13.0。因此,我需要引导用户查看对应版本的官方文档。通常,TensorFlow官网会保留历史版本的文档,用户可以通过版本选择器切换。 然后,考虑用户可能遇到的问题,比如如何定义Estimator、训练模型、使用输入函数等。需要提供基本步骤和示例代码,但代码必须符合1.13.0的语法。例如,在1.x版本中,可能需要使用tf.estimator.DNNClassifier等预定义的Estimator,以及使用tf.data.Dataset构建输入管道。 另外,用户提到参考中的内容涉及导出检查点和使用TensorBoard监控训练进度。这部分在1.13.0中应该也是支持的,需要说明如何在代码中实现这些功能,比如在Estimator的配置中设置模型目录,以便TensorBoard读取日志。 需要注意的潜在问题包括版本兼容性,比如某些API在后续版本中可能已弃用或更改。例如,tf.contrib中的部分功能在2.0中被移除,但在1.13.0中可能还存在。因此,提醒用户注意版本差异,并建议可能的升级路径,如果可行的话。 用户可能需要示例代码,所以需要准备一个简单的例子,比如线性回归或分类模型,使用Estimator的流程。同时,强调文档的重要性,指导用户如何查找特定版本的文档,并提供链接或方法。 最后,生成相关问题时要覆盖常见的使用场景、问题排查和升级考虑,帮助用户全面了解可能遇到的挑战。</think>关于TensorFlow Estimator 1.13.0的使用,以下是关键信息整理: 1. **核心功能定位**: - EstimatorTensorFlow 1.x系列中的高阶API,主要用于封装完整的机器学习工作流,包括训练、评估、预测和模型导出[^1] - 在1.13.0版本中,支持预定义模型(如`DNNClassifier`)和自定义模型两种使用方式 2. **典型使用流程**: ```python # 定义输入函数 def input_fn(): dataset = tf.data.Dataset.from_tensor_slices((features, labels)) return dataset.shuffle(1000).batch(32).repeat() # 构建预定义Estimator classifier = tf.estimator.DNNClassifier( feature_columns=[tf.feature_column.numeric_column('x', shape=[4])], hidden_units=[10, 20], model_dir='/tmp/model' ) # 执行训练 classifier.train(input_fn=input_fn, steps=2000) ``` 3. **版本特性说明**: - 支持与`tf.data`API无缝集成构建数据管道 - 包含`tf.estimator.RunConfig`用于配置运行参数 - 提供`train_and_evaluate`方法实现训练与验证的交替执行 4. **文档获取方式**: - 访问TensorFlow 1.13.0官方文档: ``` https://www.tensorflow.org/versions/r1.13/api_docs/python/tf/estimator/Estimator ``` - 使用本地文档查看(需安装对应版本): ```bash python -m pydoc tensorflow.estimator.DNNClassifier ``` 5. **注意事项**: - 该版本默认使用静态计算图模式 - 模型导出需使用`tf.estimator.export`模块 - TensorBoard日志默认保存在`model_dir`指定路径
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值