sklearn训练模型、保存模型文件(文本、pkl)、模型文件转换(pkl2onnx)以及模型可视化

该博客介绍了使用Python进行模型操作的相关内容。使用Jupyter Lab和Python2 kernel,借助GraphViz和Netron进行模型可视化,在onnx容器中进行模型转化。包含创建训练模型、保存为pdf、用joblib和pickle保存模型并预测,还介绍了pkl转onnx格式及用Netron查看模型。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.使用环境

IDE:Jupyter Lab,使用Python2 kernel实现

模型可视化:GraphViz,可以直接在jupyter中使用;Netron    window版本

模型转化:在onnx/onnx-ecosystem容器中进行

2.代码

创建并训练模型

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd


from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()

# 训练模型
clf =  tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
with open("iris.dot", 'w') as f:
    f = tree.export_graphviz(clf, out_file=f)


from IPython.display import Image  
import pydotplus

dot_data = tree.export_graphviz(clf, out_file=None, 
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)


graph = pydotplus.graph_from_dot_data(dot_data)

# 模型可视化
Image(graph.create_png())


 

将图片保存为pdf

#设置环境变量,解决调用graph时“InvocationException: GraphViz's executables not found”的错误。

import os
os.environ["PATH"] += os.pathsep + 'D:/Anaconda2/Library/bin/graphviz/' 

dot_data = tree.export_graphviz(clf, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data) 
graph.write_pdf("iris.pdf")

使用joblib保存模型为pkl格式,并读取pkl格式的模型文件进行预测

from sklearn.externals import joblib
joblib.dump(clf, "DecisionTreeClassifier.pkl")

f1=joblib.load('DecisionTreeClassifier.pkl')

f1.score(iris.data, iris.target)


使用pickle保存模型为文本格式并读取通过pickle保存的模型文件进行预测

import pickle
s=pickle.dumps(clf)
f=open('DecisionTreeClassifier.txt','w')
f.write(s)
f.close()

f2=open('DecisionTreeClassifier.txt','r')
s2=f2.read()
clf2=pickle.loads(s2)
clf2.score(iris.data, iris.target)

模型格式转换

在onnx/onnx-ecosystem容器执行如下代码:

将pkl格式的模型文件转换为onnx:DecisionTreeClassifier.pkl  ----> model.onnx

from sklearn.externals import joblib
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import *
import onnxmltools

# Update the input name and path for your sklearn model
input_skl_model = 'DecisionTreeClassifier.pkl'

# input data type for your sklearn model
input_data_type = [('float_input', FloatTensorType([1, 4]))]

# Change this path to the output name and path for the ONNX model
output_onnx_model = 'model.onnx'

# Load your sklearn model
skl_model = joblib.load(input_skl_model)

# Convert the sklearn model into ONNX
onnx_model = onnxmltools.convert_sklearn(skl_model, initial_types=input_data_type)

# Save as protobuf
onnxmltools.utils.save_model(onnx_model, output_onnx_model)

3.使用Netron查看pkl模型和onnx模型

查看pkl格式的模型

 

查看onnx格式的模型

 

<think>好的,用户的问题是如何使用sklearn模型保存pkl文件。我需要参考提供的引用资料来找到正确的答案。首先,引用[2]和[4]提到了使用joblib来保存模型pkl文件。引用[4]中的代码示例显示,可以使用joblib.dump方法,并指定文件名后缀为.pkl。另外,引用[3]和[4]也提到了pickle模块的方法,但用户特别询问的是pkl格式,而joblib通常更推荐用于sklearn模型,尤其是当模型包含大型numpy数组时。需要注意的是,引用[4]中指出,新版sklearn中joblib需要直接导入,而不是从sklearn.externals引入,这可能影响代码的正确性。因此,正确的步骤应该是:训练模型后,导入joblib,使用dump函数保存模型,指定文件名,如'clf.pkl'。同时,compress参数可以用于压缩,节省空间。需要提醒用户关于跨平台兼容性的问题,如32位和64位系统的差异,引用[4]也提到了这一点。最后,确保示例代码正确,并对比两种方法,但用户主要关注的是joblib的方法,所以重点放在那里。需要检查是否有过时的导入方式,并给出最新版本的代码示例。</think>在scikit-learn中保存模型为.pkl文件有两种主流方法,这里推荐使用joblib方式,以下是具体操作步骤和代码示例: ### 方法一:使用joblib保存(推荐) ```python from sklearn.svm import SVC from sklearn.datasets import load_iris from joblib import dump, load # 训练模型 X, y = load_iris(return_X_y=True) clf = SVC().fit(X, y) # 保存模型到当前目录 dump(clf, 'model.pkl') # 加载模型 loaded_clf = load('model.pkl') # 使用模型预测 print(loaded_clf.predict(X[0:1])) # 输出:[0] ``` ### 方法二:使用pickle保存 ```python import pickle from sklearn.svm import SVC # 训练模型 clf = SVC().fit(X, y) # 保存模型 with open('model.pkl', 'wb') as f: pickle.dump(clf, f) # 加载模型 with open('model.pkl', 'rb') as f: loaded_clf = pickle.load(f) ``` ### 关键差异说明 1. **存储效率**:joblib对含大数组的模型存储效率更高(如神经网络),pickle适合小型模型[^4] 2. **压缩支持**:joblib可通过`compress=3`参数启用压缩(0-9级别),节省50%+存储空间[^4] 3. **版本兼容**:需确保保存/加载时使用相同Python版本(32位/64位不兼容)[^4] ### 最佳实践建议 1. 使用`dump()`时建议指定版本号: ```python dump(clf, 'model_v1.0.0.pkl') ``` 2. 保存时记录训练环境: ```python import platform print(f"Saved with sklearn {sklearn.__version__} on {platform.platform()}") ```
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

汀桦坞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值