《Estimator工程实现》系列三: 模型导出示例

本文详述了如何使用TensorFlow Estimator进行模型训练后的导出流程,包括serving_input_receiver_fn()与export_outputs的定义及应用,最终实现模型的pb文件输出,便于后续服务部署。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本人原文简书博客地址为: https://www.jianshu.com/p/72058da4d7f7

使用场景

利用estimator进行完模型训练后,我们需要将模型导出成pb文件,以便日后服务部署。在部署serving之前,我们需要先将模型导出成 standard SavedModel format

利用estimator 进行模型导出,只需要指明模型的输入和输出,然后导出即可。主要涉及到的方法为:serving_input_receiver_fn(), export_outputs, tf.estimator.Estimator.export_savedmodel

博客中所构建代码,之后会上传至github中,方便大家使用estimator 进行模型导出服务。

1. serving_input_receiver_fn()官方API描述

在训练期间,input_fn()摄取数据并准备供模型使用。在服务时,类似地, serving_input_receiver_fn()接受推理请求并为模型准备它们。此功能具有以下用途:

  • 要向图表添加占位符,服务系统将使用推理请求进行提供。
  • 添加将输入格式的数据转换Tensor为模型预期的功能所需的任何其他操作。

该函数返回一个tf.estimator.export.ServingInputReceiver对象,该对象将占位符和结果特征打包Tensor在一起。

2. serving_input_receiver_fn()实现示例

tensorflow 提供了好几种serving_input_fn()。 在这里我使tf.estimator.export.ServingInputReceiver进行构建输入方法。
在代码中,features指代传入到model_fn()feature dict。也就是传入的特征字典。其中的images对应本人使用的模型中对应的images tensor。而receiver_tensors 指代我们要传入的数据字典。中间可以添加功能块,对传入的数据tensor进行处理,转换成model_fn中需要的feature

def raw_serving_input_fn():
    serialized_tf_example = tf.placeholder(tf.float32, shape=[None, FLAGS.train_image_size,FLAGS.train_image_size,3], name="images")
    features = {"images": serialized_tf_example}
    receiver_tensors = {'predictor_inputs': serialized_tf_example}
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

3. export_outputs官方API描述

编写自定义时model_fn,必须填充返回值的export_outputs元素tf.estimator.EstimatorSpec。这是 {name: output}描述在服务期间要导出和使用的输出签名的单词。

在进行单个预测的通常情况下,此dict包含一个元素,并且它name是无关紧要的。在多头模型中,每个头部由该词典中的条目表示。在这种情况下name,您可以选择一个字符串,用于在服务时请求特定的头部。

每个output值必须是一个ExportOutput对象,例如 tf.estimator.export.ClassificationOutputtf.estimator.export.RegressionOutput,或 tf.estimator.export.PredictOutput

4. export_outputs实现示例

订好好输出的字典后,我们需要将该输出字典传入到tf.estimator.EstimatorSpecexport_outputs,以确定该estimator的导出模型的输出内容。

    predictions = {'pred_x_ratio': pred_x_ratio, 'pred_y_ratio': pred_y_ratio, 'pred_v': visiable_pre_argmax}

    output = {'serving_default': tf.estimator.export.PredictOutput(predictions)}
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                              mode=mode,
                              predictions=predictions,
                              loss=None, train_op=None,
        export_outputs=output )


5. 执行模型导出

通过export_savedmodel导出模型即可。

estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn,
                            strip_default_attrs=True)

6. 检查输出的模型结构

进入导出的模型文件夹路径,通过官方提供的指令查看图输出结果

saved_model_cli show --dir . --all
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \
  sudo apt-key add -

本人模型的输出图结果如下,对应着以上的输入输出定义结果。

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['predictor_inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 384, 384, 3)
        name: images:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['pred_v'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 24)
        name: Identity_2:0
    outputs['pred_x_ratio'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 24)
        name: Reshape_1:0
    outputs['pred_y_ratio'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 24)
        name: Reshape_2:0
  Method name is: tensorflow/serving/predict

参考文档:

https://tensorflow.google.cn/versions/r1.9/guide/saved_model?hl=en#using_savedmodel_with_estimators

<think>我们正在处理一个关于随机森林分类器(RandomForestClassifier)的问题,涉及可视化、使用模型进行预测以及导出模型。根据用户提供的引用,我们可以从以下几个方面展开: 1. **可视化随机森林中的单棵决策树**:引用[1][2][3]都提到了如何从随机森林中提取单棵决策树并进行可视化。 2. **使用模型进行预测**:引用[4]展示了如何训练模型并进行预测。 3. **导出训练好的模型**:虽然引用中没有直接提到,但这是常见的需求,我们可以补充。 因此,我们将按照以下步骤组织回答: - 训练随机森林模型 - 可视化单棵或多棵决策树 - 使用模型进行预测 - 导出模型(如保存为文件) 注意:可视化决策树需要安装graphviz和pydotplus。可以使用pip安装: ``` pip install graphviz pydotplus ``` 下面开始逐步实现:</think>### RandomForestClassifier 可视化、使用与导出方法 #### 1. **可视化随机森林模型** ##### a) 可视化单棵决策树 ```python from sklearn.ensemble import RandomForestClassifier from sklearn import tree import matplotlib.pyplot as plt import pydotplus from IPython.display import Image # 训练随机森林模型 rf = RandomForestClassifier(n_estimators=100, max_depth=3, random_state=42) rf.fit(X_train, y_train) # 选择森林中的第5棵树进行可视化 estimator = rf.estimators_[5] # 创建决策树图形 plt.figure(figsize=(15, 10)) tree.plot_tree(estimator, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True) plt.title("随机森林中的单棵决策树可视化") plt.show() ``` ##### b) 可视化多棵决策树(Jupyter Notebook) ```python from IPython.core.display import HTML, display for i, estimator in enumerate(rf.estimators_[:3]): # 可视化前3棵树 dot_data = tree.export_graphviz( estimator, out_file=None, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True, special_characters=True ) graph = pydotplus.graph_from_dot_data(dot_data) svg = graph.create_svg().decode("utf-8") display(HTML(f"<h3>决策树 #{i+1}</h3>" + svg)) ``` #### 2. **使用模型进行预测** ```python # 预测单个样本 sample = [[5.1, 3.5, 1.4, 0.2]] # 鸢尾花样本特征 prediction = rf.predict(sample) probabilities = rf.predict_proba(sample) print(f"预测类别: {iris.target_names[prediction[0]]}") print(f"类别概率: {dict(zip(iris.target_names, probabilities[0]))}") # 预测整个测试集 y_pred = rf.predict(X_test) y_proba = rf.predict_proba(X_test) # 评估模型性能 from sklearn.metrics import accuracy_score, classification_report print(f"测试集准确率: {accuracy_score(y_test, y_pred):.2f}") print("\n分类报告:\n", classification_report(y_test, y_pred, target_names=iris.target_names)) ``` #### 3. **特征重要性可视化** ```python importances = rf.feature_importances_ features = iris.feature_names plt.figure(figsize=(10, 6)) plt.barh(features, importances, color='skyblue') plt.xlabel('特征重要性') plt.title('随机森林特征重要性分析') plt.grid(axis='x', linestyle='--', alpha=0.7) plt.show() ``` #### 4. **导出和保存模型** ##### a) 保存为文件 ```python import joblib # 保存模型 joblib.dump(rf, 'random_forest_model.joblib') # 加载模型 loaded_rf = joblib.load('random_forest_model.joblib') # 使用加载的模型预测 loaded_prediction = loaded_rf.predict(X_test) ``` ##### b) 导出决策树规则 ```python # 导出第一棵决策树的决策规则 tree_rules = tree.export_text( rf.estimators_[0], feature_names=iris.feature_names ) print("决策树规则示例:\n", tree_rules[:500]) # 打印前500个字符 ``` ##### c) 导出为可视化文件 ```python # 导出单棵决策树为PDF dot_data = tree.export_graphviz( rf.estimators_[0], out_file=None, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True ) graph = pydotplus.graph_from_dot_data(dot_data) graph.write_pdf("decision_tree.pdf") ``` ### 关键解释 1. **可视化意义**: - 单棵决策树展示模型的具体决策逻辑[^3] - 特征重要性揭示关键影响因素[^4] 2. **模型使用**: - `predict()` 返回预测类别 - `predict_proba()` 返回类别概率分布 3. **模型导出**: - `joblib` 适合保存大型模型(如随机森林)[^1] - 导出规则可用于业务解释和系统集成 4. **随机森林特性**: - 通过聚合多棵树的预测提高准确性 - 天然处理特征间非线性关系 - 提供特征重要性评估 ### 输出示例 ``` 测试集准确率: 0.97 分类报告: precision recall f1-score support setosa 1.00 1.00 1.00 13 versicolor 1.00 0.94 0.97 16 virginica 0.90 1.00 0.95 9 accuracy 0.97 38 macro avg 0.97 0.98 0.97 38 weighted avg 0.98 0.97 0.97 38 ``` ![决策树可视化](https://via.placeholder.com/600x400?text=Decision+Tree+Visualization) ![特征重要性](https://via.placeholder.com/600x300?text=Feature+Importance) ---
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值