TensorFlow中的高级API——Estimator

本文介绍TensorFlow Estimator的使用方法及优势,包括预置评估器与自定义评估器的区别,如何构建模型,以及如何利用Estimator进行训练、评估和预测。Estimator简化了模型开发流程,支持多种硬件平台,并提供了分布式训练的支持。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原文:
https://www.tensorflow.org/programmers_guide/estimators

Estimator 封装了以下部分:

  • 训练
  • 评估
  • 预测
  • 导出用于服务

评估器的优势

  • 基于Estimator的模型可以运行于本地,也可以运行在分布式环境中。更可以运行于CPU、GPU和TPU之上而无需修改代码
  • 方便了模型开发者间的分享
  • 创建模型比使用低等级API更容易
  • Estimator自身建立在tf.layers,简化定制
  • Estimator自动建立图表
  • 提供了安全的分布式训练循环来决定怎样以及何时:
    • 建立图表
    • 初始化变量
    • 开始队列
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存TensorBoard的摘要

使用Estimators编写应用程序时,必须将数据输入管道与模型分开。这种分离简化了不同数据集的实验。

预置评估器(Pre-made Estimators)

预置评估器创建和管理GraphSession对象,只需进行最少量的代码更改即可尝试不同的模型架构。例如, DNNClassifier 是一个预置的Estimator类,它通过密集的前馈神经网络来训练分类模型。

基于预置评估器的程序结构

  1. 写一个或多个数据集导入函数。例如训练集和测试集。每个数据集导入函数必须返回两个对象:

    • 一个字典,其中的键是特征名称,值是包含相应特征数据的Tensors(或SparseTensors)
    • 包含一个或多个标签的Tensor

    例如,下面的代码演示了输入函数的基本框架:

    def input_fn(dataset):
    ...  # manipulate dataset, extracting feature names and the label
    return feature_dict, label
  2. 定义特征列。每个tf.feature_column识别一个特征的名称,类型和任何输入预处理。以下片段创建了三个保存整数或浮点数据的特征列。 前两个特征列只是标识特征的名称和类型。 第三个特性列还指定了程序将调用的用于缩放原始数据的lambda:

    
    # Define three numeric feature columns.
    
    population = tf.feature_column.numeric_column('population')
    crime_rate = tf.feature_column.numeric_column('crime_rate')
    median_education =    tf.feature_column.numeric_column('median_education',
                    normalizer_fn='lambda x: x - global_education_mean')
  3. 实例化相关的预置评估器。例如:

    
    # Instantiate an estimator, passing the feature columns.
    
    estimator = tf.estimator.Estimator.LinearClassifier(
      feature_columns=[population, crime_rate, median_education],
      )
  4. 调用训练、评估、预测方法。例如,所有评估器都提供训练方法:

    
    # my_training_set is the function created in Step 1
    
    estimator.train(input_fn=my_training_set, steps=2000)

预置评估器的优点:

Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.

Best practices for event (summary) writing and universally useful summaries.

自定义评估器

每个评估器的核心——无论是预置还是自定义——都是其模型函数,它是一种为训练,评估和预测构建图形的方法。当使用预置评估器时,其他人已经实现了该模型功能。当使用自定义评估器时,须自己编写模型函数。

推荐工作流程

  1. 假设存在合适的预置评估器,使用它构建第一个模型并使用其结果建立基线。
  2. 使用此预置评估器构建和测试整体流水线,包括数据的完整性和可靠性。
  3. 如果有合适的替代预置评估器可用,则运行实验以确定哪种预置评估器能够产生最佳结果。
  4. 可能需要通过构建自定义评估器来进一步改进模型。
### TensorFlow 2.x 中 Estimator 的兼容性和替代方案 在 TensorFlow 2.x 中,为了保持向后兼容性并支持从旧版代码迁移,`tf.compat.v1.estimator` 被保留下来用于继续使用基于 `estimator` 构建的模型[^1]。然而,官方更推荐采用 Keras API 来构建新项目中的训练循环和评估逻辑,因为这能更好地利用 TF2 提供的新特性和性能优化。 对于希望迁移到最新框架下的开发者来说,可以考虑以下两种方式: #### 方法一:通过 tf.compat.v1 继续使用 estimator 如果现有应用程序依赖于 `tf.estimator` 或其派生类,则可以通过导入 `tf.compat.v1` 命名空间来维持原有功能不变。这种方式允许逐步更新其他部分而不立即改变整个工作流结构。 ```python import tensorflow as tf # 使用兼容模式加载Estimator模块 from tensorflow import compat my_estimator = compat.v1.estimator.Estimator(model_fn=my_model_function) ``` 这种方法虽然简单直接,但在长期维护方面可能不如完全过渡至新的接口理想。 #### 方法二:转向 keras Model 和自定义 Training Loop 另一种更为现代的选择是重构代码以适应 Keras 高级抽象层——即创建子类化的 `Model` 对象,并实现定制化训练过程。此路径不仅简化了编码流程,还能够充分利用 GPU/TPU 加速以及分布式计算等功能特性。 ```python class MyCustomModel(tf.keras.Model): def __init__(self, ... ): super(MyCustomModel, self).__init__() ... def call(self, inputs): ... model = MyCustomModel(...) optimizer = tf.keras.optimizers.Adam() loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) @tf.function def train_step(images, labels): with tf.GradientTape() as tape: predictions = model(images, training=True) loss = loss_object(labels, predictions) gradients = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(gradients, model.trainable_variables)) for epoch in range(EPOCHS): for images, labels in dataset: train_step(images, labels) ``` 上述例子展示了如何用简洁明了的方式替换传统的 `estimator` 设计思路,在此基础上还可以进一步集成回调函数(callbacks)、保存检查点(checkpoints)等实用工具[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值