(三)高阶api: Estimator
流程:
- 假设存在合适的预创建的 Estimator,使用它构建模型。
- 训练模型
- 评估模型
- 可以通过构建自定义 Estimator 进一步改进模型
1)实例化相关的预创建的 Estimator。
estimator = tf.estimator.LinearClassifier(
feature_columns=[population, crime_rate, median_education],
)
可以指定优化算法,模型结构。不同的模型,创建Estimator参考api.
estimator = DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
optimizer=tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
))
2)训练
estimator.train(input_fn=my_training_set, steps=2000)
3)预测与评估
estimator.evaluate(input_fn=my_predict_set)
estimator.predict(input_fn=my_predict_set)
本文介绍如何使用TensorFlow的Estimator API进行模型训练、评估及预测。通过实例演示了LinearClassifier和DNNClassifier的使用,包括特征列定义、模型参数设置、训练集输入、模型训练与评估的具体步骤。
5213

被折叠的 条评论
为什么被折叠?



