1.预创建的Estimator
- Estimator 是 Tensorflow 完整模型的高级表示,它被设计用于轻松扩展和异步训练。
- 在 Tensorflow 2.0 中,Keras API 可以完成许多相同的任务,而且被认为是一个更易学习的API。
- Tensorflow提供了一组tf.estimator(例如,LinearRegressor)来实现常见的机器学习算法。
为了编写基于预创建的 Estimator 的 Tensorflow 项目,您必须完成以下工作:
创建一个或多个输入函数
定义模型的特征列
实例化一个 Estimator,指定特征列和各种超参数。
在 Estimator 对象上调用一个或多个方法,传递合适的输入函数以作为数据源。
def input_fn(features, labels, training=True, batch_size=256):
"""An input function for training or evaluating"""
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
my_feature_columns = []
for key in train.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
实例化 Estimator
鸢尾花为题是一个经典的分类问题。幸运的是,Tensorflow 提供了几个预创建的 Estimator 分类器,其中包括:
1. tf.estimator.DNNClassifier 用于多类别分类的深度模型
2. tf.estimator.DNNLinearCombinedClassifier 用于广度与深度模型
3. tf.estimator.LinearClassifier 用于基于线性模型的分类器
对于鸢尾花问题,tf.estimator.DNNClassifier 似乎是最好的选择。您可以这样实例化该 Estimator:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[30, 10],
n_classes=3)
classifier.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=5000)
评估经过训练的模型
现在模型已经经过训练,您可以获取一些关于模型性能的统计信息。代码块将在测试数据上对经过训练的模型的准确率(accuracy)进行评估:
eval_result = classifier.evaluate(
input_fn=lambda: input_fn(test, test_y, training=False))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
利用经过训练的模型进行预测(推理)
我们已经有一个经过训练的模型,可以生成准确的评估结果。我们现在可以使用经过训练的模型,根据一些无标签测量结果预测鸢尾花的品种。与训练和评估一样,我们使用单个函数调用进行预测:
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7