TensorFlow模型训练与评估实战:使用tf.estimator.train_and_evaluate()

TensorFlow模型训练与评估实战:使用tf.estimator.train_and_evaluate()

training-data-analyst Labs and demos for courses for GCP Training (http://cloud.google.com/training). training-data-analyst 项目地址: https://gitcode.com/gh_mirrors/tr/training-data-analyst

本文是Google Cloud Platform培训数据分析师项目中关于TensorFlow模型训练与评估的实战教程。我们将深入探讨如何使用tf.estimator.train_and_evaluate()方法进行模型训练和评估,以及如何为生产环境准备模型。

核心概念与技术要点

1. 训练与评估输入函数

在TensorFlow中,我们首先需要定义数据输入管道。本教程中使用了两种输入函数:

  • train_input_fn():用于训练数据,包含数据打乱(shuffle)和重复(repeat)操作
  • eval_input_fn():用于评估数据,仅进行简单的批次处理
def train_input_fn(csv_path, batch_size=128):
    dataset = read_dataset(csv_path)
    dataset = dataset.shuffle(buffer_size=1000).repeat(count=None).batch(batch_size=batch_size)
    return dataset

def eval_input_fn(csv_path, batch_size=128):
    dataset = read_dataset(csv_path)
    dataset = dataset.batch(batch_size=batch_size)
    return dataset

2. 服务输入接收函数(serving_input_receiver_fn)

在生产环境中,我们通常需要将模型部署为服务,供远程客户端通过REST API访问。为此,我们需要定义serving_input_receiver_fn()函数:

def serving_input_receiver_fn():
    receiver_tensors = {
        'dayofweek': tf.placeholder(dtype=tf.int32, shape=[None]),
        'hourofday': tf.placeholder(dtype=tf.int32, shape=[None]),
        # 其他特征...
    }
    features = receiver_tensors
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

这个函数有两个主要作用:

  1. 定义TensorFlow Serving接收的输入张量的占位符
  2. 添加任何必要的操作,将客户端数据转换为模型期望的格式

3. 训练与评估配置

传统方法是在训练完成后才进行评估,这无法及时发现过拟合。train_and_evaluate()方法允许我们在训练过程中定期评估模型:

config = tf.estimator.RunConfig(
    model_dir=OUTDIR,
    tf_random_seed=1,
    save_checkpoints_steps=100  # 每100步保存一次检查点
)

model = tf.estimator.DNNRegressor(
    hidden_units=[10,10],
    feature_columns=feature_cols,
    config=config
)

4. 添加自定义评估指标

默认的DNNRegressor只提供MSE指标,我们可以通过tf.contrib.estimator.add_metrics()添加RMSE指标:

def my_rmse(labels, predictions):
    pred_values = tf.squeeze(input=predictions["predictions"], axis=-1)
    return {
        "rmse": tf.metrics.root_mean_squared_error(labels=labels, predictions=pred_values)
    }

model = tf.contrib.estimator.add_metrics(estimator=model, metric_fn=my_rmse)

5. 训练与评估规范

我们需要定义训练和评估规范:

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: train_input_fn("./taxi-train.csv"),
    max_steps=500
)

eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: eval_input_fn("./taxi-valid.csv"),
    steps=None,
    start_delay_secs=1,
    throttle_secs=1,
    exporters=exporter
)

tf.estimator.train_and_evaluate(estimator=model, 
                              train_spec=train_spec, 
                              eval_spec=eval_spec)

6. TensorBoard监控

TensorBoard是可视化训练过程的重要工具:

# 启动TensorBoard
get_ipython().system_raw(
    "tensorboard --logdir {} --host 0.0.0.0 --port 6006 &"
    .format(OUTDIR)

生产环境准备

训练完成后,模型会被导出为SavedModel格式,这是TensorFlow Serving的标准格式。导出的模型位于taxi_trained/export目录下,可以部署到生产环境中。

最佳实践

  1. 定期评估:设置合理的save_checkpoints_steps,确保及时发现过拟合
  2. 自定义指标:添加业务相关的评估指标,如本教程中的RMSE
  3. 生产准备:始终定义serving_input_receiver_fn,即使暂时不需要部署
  4. 监控:使用TensorBoard实时监控训练过程
  5. 资源管理:训练完成后及时关闭TensorBoard和ngrok进程

通过本教程,您已经掌握了使用TensorFlow Estimator API进行模型训练、评估和生产准备的全流程。这些技术可以应用于各种机器学习任务,帮助您构建更健壮、更易于部署的模型。

training-data-analyst Labs and demos for courses for GCP Training (http://cloud.google.com/training). training-data-analyst 项目地址: https://gitcode.com/gh_mirrors/tr/training-data-analyst

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

管旭韶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值