sklearn-json
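Preface
Goal: export a model trained with sklearn to JSON so that it can be passed between programs written in different languages.
Approach: use sklearn-json.
Installing sklearn-json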
pip install sklearn-json
Note: requires scikit-learn >= 0.21.3.
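Before installing, it can help to confirm that the local scikit-learn meets this minimum. The following is a minimal sketch using only the standard library; the helper names (parse_version, meets_minimum, check_installed) are hypothetical and not part of sklearn-json, and the parsing is a simplification that looks only at the first three numeric components.

```python
from importlib import metadata

MINIMUM = "0.21.3"  # minimum scikit-learn version stated in the note above


def parse_version(text):
    """Turn '1.4.2' into (1, 4, 2); a suffix such as 'rc1' is ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad so that '1.0' compares like '1.0.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed, minimum=MINIMUM):
    """Return True when the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


def check_installed():
    """Return True or False, or None when scikit-learn is not installed."""
    try:
        installed = metadata.version("scikit-learn")
    except metadata.PackageNotFoundError:
        return None
    return meets_minimum(installed)
```

Calling check_installed() returns None when scikit-learn is missing, which keeps the "not installed" case separate from the "too old" case.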
Usage
Serializing a model to JSON
The example below uses a decision tree classifier.
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import sklearn_json as skljson
# data
wine = load_wine()
# train/test split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)
# train a decision tree
clf = tree.DecisionTreeClassifier(criterion='gini', max_depth=5, random_state=0)
clf = clf.fit(Xtrain, Ytrain) # after fit, clf is the model
# save model to json
skljson.to_json(clf, "tree_model")  # this call is the key step
At this point, the decision tree classifier has been saved in JSON format.
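Because the output is ordinary JSON, any JSON parser can read it back without sklearn installed. As a rough sketch, the helper below (top_level_keys is a hypothetical name, not part of sklearn-json) loads a file with Python's standard json module and lists its top-level keys. It assumes the file holds a single JSON object.

```python
import json


def top_level_keys(path):
    """Load a JSON file and return its top-level keys in sorted order."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object at the top level")
    return sorted(data)
```

After running the training script above, top_level_keys("tree_model") would list the fields that sklearn-json wrote for the tree.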
JSON is human-readable. Opening the "tree_model" file shows the following:
{
"meta":