场景
需要在java调用python sklearn训练评估的模型,本文介绍使用pmml来实现。
生成pmml文件
#引入sklearn2pmml包
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline
#使用PMMLPipeline包裹具体评估器
clf = PMMLPipeline([("MLPClassifier", MLPClassifier(hidden_layer_sizes=(25,), random_state=1, max_iter=100, warm_start=True))])
clf.fit(value, label)
#保存模型到指定文件
sklearn2pmml(clf, "MLPClassifier.pmml", with_repr=True)
JAVA调用模型
引用java maven依赖包
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.5.15</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.5.15</version>
</dependency>
java加载模型并评估
Map<String, Object> paramData = new HashMap<>();
paramData.put("x1", 180D);
paramData.put("x2", 350D);
FileInputStream inputStream = new FileInputStream("MLPClassifier.pmml");
//解析pmml文件,实际上是用JAXB做xml的解析
PMML pmml = PMMLUtil.unmarshal(inputStream);
//生成评估器
ModelEvaluator<?> evaluate = new ModelEvaluatorBuilder(pmml).build();
//构建输入参数
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
List<InputField> inputFields = evaluate.getInputFields();
for (InputField inputField : inputFields) { //将参数通过模型对应的名称进行添加
FieldName inputFieldName = inputField.getName(); //获取模型中的参数名
Object paramValue = paramData.get(inputFieldName.getValue()); //获取模型参数名对应的参数值
FieldValue fieldValue = inputField.prepare(paramValue); //将参数值填入模型中的参数中
arguments.put(inputFieldName, fieldValue); //存放在map列表中
}
//开始评估
Map<FieldName, ?> target = evaluate.evaluate(arguments);
//获取评估结果
List<TargetField> targetFields = evaluate.getTargetFields();
Object targetFieldValue = target.get(targetFields.get(0).getFieldName());
System.out.println("targetFieldValue: " + targetFieldValue);
System.out.println("target: " + target);
注意事项
1.注意生成模型的版本和java依赖包的版本要匹配,否则java侧会无法解析该pmml模型

本文介绍如何使用pmml实现Java调用Python sklearn训练的模型。首先利用sklearn2pmml生成PMML文件,然后Java应用加载此文件进行预测。需确保Python模型版本与Java依赖包版本一致。
1138





