很多情况下,线上服务一般使用 Java,而机器学习模型一般用 Python 训练。这就带来一个问题:Python 训练出的模型,Java 如何调用它进行线上预测?下面以决策树(DecisionTreeClassifier)为例介绍这一过程,同样的做法也适用于随机森林等其他 sklearn 模型。
Python 训练并导出 PMML 的脚本如下(需要安装 sklearn2pmml,且本机需安装 Java):
#!/usr/bin/python
# -*- coding:utf-8 -*-
# Train a toy decision tree with scikit-learn and export it to demo.pmml so the
# Java side (JPMML evaluator) can load it for online prediction.
from sklearn import tree
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml
from sklearn2pmml import make_pmml_pipeline  # converts a pickled (.pkl) model into a PMML pipeline; not used below
import os

# sklearn2pmml shells out to Java, so a `java` executable must be on PATH.
# Prefer JAVA_HOME when it is set; otherwise fall back to the JDK path this
# demo was originally written against.
java_home = os.environ.get("JAVA_HOME", "C:/Program Files/Java/jdk1.8.0_171")
os.environ["PATH"] += os.pathsep + os.path.join(java_home, "bin")

# Four features per row. With no column names, the exported PMML is expected
# to call them x1..x4 (the names the Java side looks up) -- confirm in demo.pmml.
X = [[1, 2, 3, 1], [2, 4, 1, 5], [7, 8, 3, 6], [4, 8, 4, 7], [2, 5, 6, 9]]
y = [0, 1, 0, 2, 1]

# NOTE(review): the original note claimed PMMLPipeline handles only estimators,
# not transformers; sklearn2pmml does accept many transformer steps -- confirm
# against the installed version.
# To plug custom functions into a PMMLPipeline, see
# https://blog.youkuaiyun.com/weixin_38569817/article/details/87810658
pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier(random_state=9))])
pipeline.fit(X, y)

# Plain relative file name (no backslash): it is portable across operating
# systems and matches the path the Java loader opens ("demo.pmml").
sklearn2pmml(pipeline, "demo.pmml", with_repr=True)
Java 端添加 Maven 依赖:
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.1</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.1</version>
</dependency>
Java 调用代码如下:
package xxx;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
public class PMMLDemo {
private Evaluator loadPmml(){
PMML pmml = new PMML();
InputStream inputStream = null;
try {
inputStream = new FileInputStream("demo.pmml"); # 需要引用python保存下来的pmml文件
} catch (IOException e) {
e.printStackTrace();
}
if(inputStream == null){
return null;
}
InputStream is = inputStream;
try {
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
} catch (SAXException e1) {
e1.printStackTrace();
} catch (JAXBException e1) {
e1.printStackTrace();
}finally {
//关闭输入流
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
pmml = null;
return evaluator;
}
private int predict(Evaluator evaluator,int a, int b, int c, int d) {
Map<String, Integer> data = new HashMap<String, Integer>();
data.put("x1", a);
data.put("x2", b);
data.put("x3", c);
data.put("x4", d);
List<InputField> inputFields = evaluator.getInputFields();
//过模型的原始特征,从画像中获取数据,作为模型输入
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object rawValue = data.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields();
TargetField targetField = targetFields.get(0);
FieldName targetFieldName = targetField.getName();
Object targetFieldValue = results.get(targetFieldName);
System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
int primitiveValue = -1;
if (targetFieldValue instanceof Computable) {
Computable computable = (Computable) targetFieldValue;
primitiveValue = (Integer)computable.getResult();
}
System.out.println(a + " " + b + " " + c + " " + d + ":" + primitiveValue);
return primitiveValue;
}
public static void main(String args[]){
PMMLDemo demo = new PMMLDemo();
Evaluator model = demo.loadPmml();
demo.predict(model,1,8,99,1); # 填充数据即可预测
demo.predict(model,111,89,9,11);
}
}
python训练模型,java预测模型(sklearn2pmml)
最新推荐文章于 2025-03-02 22:31:54 发布