前段时间要把训练好的 xgboost 模型部署成 service,中途踩了不少坑。首先,部署的机器必须先安装 gcc 环境,而且 gcc 版本要在 6.0 及以上,4.x 版本并不适用。下面看一下 service 项目的代码。
首先是接口：
package com.dianping.text.classify.api.service;
/**
 * Text prediction service: every overload takes the input text and returns a
 * {@code double} score.
 *
 * <p>NOTE(review): the meaning of {@code cata} (presumably a category), {@code rate} and
 * {@code textlength}, as well as the range of the returned score, are not visible from
 * this interface — confirm against the implementation before relying on them.
 */
public interface PredictionService {

    /**
     * Scores {@code text}.
     *
     * @param text input text
     * @return prediction score
     */
    double prediction(String text);

    /**
     * Scores {@code text} with an additional {@code cata} argument.
     *
     * @param text input text
     * @param cata category-like qualifier (assumed from the name — confirm)
     * @return prediction score
     */
    double prediction(String text, String cata);

    /**
     * Scores {@code text} with an additional {@code rate} argument.
     *
     * @param text input text
     * @param rate numeric rate parameter (semantics defined by the implementation)
     * @return prediction score
     */
    double prediction(String text, double rate);

    /**
     * Scores {@code text} with {@code rate} and {@code textlength} arguments.
     *
     * @param text input text
     * @param rate numeric rate parameter (semantics defined by the implementation)
     * @param textlength length parameter (semantics defined by the implementation)
     * @return prediction score
     */
    double prediction(String text, double rate, int textlength);

    /**
     * Scores {@code text} with {@code cata} and {@code rate} arguments.
     *
     * @param text input text
     * @param cata category-like qualifier (assumed from the name — confirm)
     * @param rate numeric rate parameter (semantics defined by the implementation)
     * @return prediction score
     */
    double prediction(String text, String cata, double rate);

    /**
     * Scores {@code text} with {@code cata}, {@code rate} and {@code textlength} arguments.
     *
     * @param text input text
     * @param cata category-like qualifier (assumed from the name — confirm); renamed from
     *     {@code cataS} to match the other overloads
     * @param rate numeric rate parameter (semantics defined by the implementation)
     * @param textlength length parameter (semantics defined by the implementation)
     * @return prediction score
     */
    double prediction(String text, String cata, double rate, int textlength);
}
接口的实现类：
package com.dianping.text.classify.serviceimpl;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.alibaba.fastjson.JSONObject;
import com.dianping.text.classify.api.service.PredictionService;
import com.dianping.text.classify.util.DataLoader;
import com.dianping.text.classify.util.PathUtil;
import com.dianping.text.classify.util.Terms;
import com.meituan.nlp.util.WordUtil;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * XGBoost-backed implementation of {@link PredictionService}.
 *
 * <p>The term map and the booster are loaded by {@link #init()}, which the Spring bean
 * definition registers as its {@code init-method}. Load failures are logged rather than
 * thrown, so the bean can come up with a {@code null} booster or an empty/partial term map.
 *
 * <p>NOTE(review): the {@code PredictionService} listing shown with this class declares only
 * {@code prediction(...)} overloads, while this class overrides {@code predictionClassify},
 * {@code segmentation} and {@code isRepeat}. As shown, the two listings do not compile
 * together — confirm which interface version is actually deployed.
 */
public class ServiceImpl implements PredictionService {
    private static final Logger logger = LoggerFactory.getLogger(ServiceImpl.class);

    /** Loaded XGBoost model; remains {@code null} if loading fails. */
    private Booster booster;

    /** Terms keyed by the first tab-separated column of {@code mapterms.model}. */
    private Map<String, Terms> mapTerms;

    /**
     * Loads {@code model/mapterms.model} and {@code model/xgboost.model} from the classpath
     * root directory.
     */
    private void init() {
        // NOTE(review): URL.getPath() returns a percent-encoded path, so a deploy directory
        // containing spaces or non-ASCII characters will not be found; getResource("/") may
        // also return null when classes are only served from a jar — verify for the target
        // deployment layout.
        String root = this.getClass().getResource("/").getPath();
        loadMapTerms(root + "model/mapterms.model");
        loadXgboost(root + "model/xgboost.model");
    }

    /**
     * Classifies {@code text} using the loaded term map and booster.
     *
     * @param text raw input text
     * @return the score computed by {@code DataLoader.getClassification}
     *     (NOTE(review): meaning and range of the score are defined there — confirm)
     */
    @Override
    public double predictionClassify(String text) {
        return DataLoader.getClassification(text, mapTerms, booster);
    }

    /**
     * Segments {@code text} with ansj {@code ToAnalysis} and keeps the term names whose
     * nature string is not {@code "w"} (presumably the punctuation tag — confirm).
     *
     * @param text input text
     * @return the kept term names in order, or {@code null} when {@code text} is blank
     */
    @Override
    public List<String> segmentation(String text) {
        if (StringUtils.isBlank(text)) {
            return null;
        }
        List<String> words = new ArrayList<String>();
        for (org.ansj.domain.Term term : ToAnalysis.parse(text).getTerms()) {
            if (!"w".equalsIgnoreCase(term.getNatureStr())) {
                words.add(term.getName());
            }
        }
        return words;
    }

    /**
     * Delegates to {@code WordUtil.isRepeat(text, rate, contentLegth)}.
     *
     * @param text input text
     * @param rate threshold passed through to {@code WordUtil}
     * @param contentLegth length argument passed through to {@code WordUtil}
     * @return whatever {@code WordUtil.isRepeat} returns
     */
    @Override
    public boolean isRepeat(String text, double rate, int contentLegth) {
        return WordUtil.isRepeat(text, rate, contentLegth);
    }

    /**
     * Loads the XGBoost model at {@code path}. On failure, logs the error and leaves
     * {@link #booster} unchanged.
     */
    private void loadXgboost(String path) {
        try {
            booster = XGBoost.loadModel(path);
        } catch (XGBoostError e) {
            logger.error("loadd model error:", e);
        }
    }

    /**
     * Rebuilds {@link #mapTerms} from {@code file}. Each line is expected to be
     * {@code key<TAB>json}, where the JSON is parsed into {@link Terms}. Malformed lines are
     * skipped. Any other error is logged, and the entries read so far are kept.
     */
    private void loadMapTerms(String file) {
        mapTerms = new HashMap<>();
        // Decode explicitly as UTF-8 instead of the platform default charset, which is not
        // guaranteed to be UTF-8 on servers. NOTE(review): assumes mapterms.model is written
        // in UTF-8 — confirm against whatever produces the file.
        try (BufferedReader in = new BufferedReader(new InputStreamReader(
                new FileInputStream(file), StandardCharsets.UTF_8))) {
            String line;
            while ((line = in.readLine()) != null) {
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    // Without this check a blank or tab-less line throws
                    // ArrayIndexOutOfBoundsException and aborts the rest of the load.
                    logger.warn("skip malformed map line: {}", line);
                    continue;
                }
                mapTerms.put(fields[0], JSONObject.parseObject(fields[1], Terms.class));
            }
        } catch (Exception e) {
            // Init-time boundary: keep the bean alive with whatever was loaded.
            logger.error("load map error:", e);
        }
    }
}
Spring 配置：
<?xml version="1.0" encoding="UTF-8"?>
<beans xmlns="http://www.springframework.org/schema/beans"
       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
       xmlns:context="http://www.springframework.org/schema/context"
       xmlns:aop="http://www.springframework.org/schema/aop"
       xmlns:tx="http://www.springframework.org/schema/tx"
       xsi:schemaLocation="
           http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans-2.5.xsd
           http://www.springframework.org/schema/context http://www.springframework.org/schema/context/spring-context-2.5.xsd
           http://www.springframework.org/schema/aop http://www.springframework.org/schema/aop/spring-aop-2.5.xsd
           http://www.springframework.org/schema/tx http://www.springframework.org/schema/tx/spring-tx-2.5.xsd">

    <!-- Prediction service bean. Spring invokes ServiceImpl.init() after creating the
         instance; init() loads the term map and the XGBoost model files. -->
    <bean id="xgboostTxtPrediction"
          class="com.dianping.text.classify.serviceimpl.ServiceImpl"
          init-method="init"/>
</beans>
web.xml 就不贴了。总体来说,部署模型时一定要注意安装 gcc 环境,否则部署会报错。另外,在 Mac 下编译的 xgboost 包并不能直接在 CentOS、Ubuntu 上使用,否则也会报错;在 CentOS 或 Ubuntu 上部署时,一定要在目标系统上重新编译 xgboost,并把编译好的 jar 包发布(deploy)到你们公司自己的 Maven 仓库中。