2.3-tensoflow2-基础教程-Estimator

本文介绍TensorFlow Estimator的应用,包括预创建Estimator、线性模型、提升树模型的理解及训练,以及如何从Keras模型转换到Estimator模型。通过鸢尾花分类问题和泰坦尼克号数据集,详细讲解了Estimator的使用方法。

1.预创建的Estimator

  1. Estimator 是 Tensorflow 完整模型的高级表示,它被设计用于轻松扩展和异步训练。
  2. 在 Tensorflow 2.0 中,Keras API 可以完成许多相同的任务,而且被认为是一个更易学习的API。
  3. Tensorflow提供了一组tf.estimator(例如,LinearRegressor)来实现常见的机器学习算法。
为了编写基于预创建的 Estimator 的 Tensorflow 项目,您必须完成以下工作:

创建一个或多个输入函数
定义模型的特征列
实例化一个 Estimator,指定特征列和各种超参数。
在 Estimator 对象上调用一个或多个方法,传递合适的输入函数以作为数据源。
def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # 将输入转换为数据集。
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # 如果在训练模式下混淆并重复数据。
    if training:
        dataset = dataset.shuffle(1000).repeat()
    
    return dataset.batch(batch_size)
# 特征列描述了如何使用输入。
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
实例化 Estimator
鸢尾花为题是一个经典的分类问题。幸运的是,Tensorflow 提供了几个预创建的 Estimator 分类器,其中包括:

 1. tf.estimator.DNNClassifier 用于多类别分类的深度模型
 2. tf.estimator.DNNLinearCombinedClassifier 用于广度与深度模型
 3. tf.estimator.LinearClassifier 用于基于线性模型的分类器

对于鸢尾花问题,tf.estimator.DNNClassifier 似乎是最好的选择。您可以这样实例化该 Estimator:

# 构建一个拥有两个隐层,隐藏节点分别为 30 和 10 的深度神经网络。
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # 隐层所含结点数量分别为 30 和 10.
    hidden_units=[30, 10],
    # 模型必须从三个类别中做出选择。
    n_classes=3)
# 训练模型。
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
评估经过训练的模型
现在模型已经经过训练,您可以获取一些关于模型性能的统计信息。代码块将在测试数据上对经过训练的模型的准确率(accuracy)进行评估:

eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
利用经过训练的模型进行预测(推理)
我们已经有一个经过训练的模型,可以生成准确的评估结果。我们现在可以使用经过训练的模型,根据一些无标签测量结果预测鸢尾花的品种。与训练和评估一样,我们使用单个函数调用进行预测:

# 由模型生成预测
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
   
   
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值