CatBoost的Java端推理

CatBoost模型的Java推理相比LightGBM会简单许多,无需转换成pmml格式,直接用官方的Java-package即可。

最主要的是,它直接支持字符串类型的类别特征,无需做各种编码转换,简直不要太6。

参考文档:https://catboost.ai/en/docs/concepts/java-package

014de0c0445dfc551db919956c537ac2.png

一,Java项目添加Maven依赖

注意version与python中的一致

<!-- https://mvnrepository.com/artifact/ai.catboost/catboost-prediction -->
<dependency>
    <groupId>ai.catboost</groupId>
    <artifactId>catboost-prediction</artifactId>
    <version>1.0.6</version>
</dependency>

二,Python端训练CatBoost模型

此处以adult数据集的二分类问题为例。

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split 
import datetime 

import numpy as np 
import pandas as pd 
import plotly.express as px 


import catboost as cb 
def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)
    print(info+'...\n\n')
    
    
#================================================================================
# 一,准备数据
#================================================================================
printlog("step1: preparing data...")


adult = fetch_openml(name = "adult",version=1,as_frame=True)

label_col = "target"
dfdata = adult["data"]
dfdata["target"] = (adult["target"]==">50K").astype("float")

cat_features = dfdata.drop("target",axis=1
                ).select_dtypes("category").columns.values.tolist()
cat_features.sort()

num_features = [col for col in dfdata.columns if col not in cat_features+[label_col]]
num_features.sort() 

dfdata[cat_features] = dfdata[cat_features].astype(str)
dfdata[cat_features].fillna("missed") 

dfdata = dfdata[num_features+cat_features+[label_col]]

dftrain_val,dftest = train_test_split(dfdata,test_size=0.3)
dftrain,dfval = train_test_split(dftrain_val,test_size=0.3)

# 整理成Pool
pool_train = cb.Pool(data = dftrain.drop(label_col,axis=1), 
                     label = dftrain[label_col], cat_features=cat_features)
pool_val = cb.Pool(data = dfval.drop(label_col,axis=1), 
                     label = dfval[label_col], cat_features=cat_features)
pool_test = cb.Pool(data = dftest.drop(label_col,axis=1), 
            label = dftest[label_col], cat_features=cat_features)


#================================================================================
# 二,设置参数
#================================================================================
printlog("step2: setting parameters...")
                               
iterations = 1000
early_stopping_rounds = 200

params = dict(
    loss_function = "Logloss",
    eval_metric = "AUC",
    random_seed = 42,
    logging_level = 'Silent',
    use_best_model = True,
    nan_mode = 'Min',
    ###
    learning_rate = 0.05,
    depth = 6,
    min_data_in_leaf = 10,
    one_hot_max_size = 5,      #类别数量多于此数将使用ordered target statistics
    boosting_type = "Ordered", #Ordered 或者Plain
    max_ctr_complexity = 2,    #特征组合的最大特征数量,设置为1取消特征组合,
)



model = cb.CatBoostClassifier(
        iterations = iterations,
        early_stopping_rounds = early_stopping_rounds,
        **params
    )


#================================================================================
# 三,训练模型
#================================================================================
printlog("step3: training model...")


model.fit(
    pool_train,
    eval_set=pool_val,
    plot=True
)



#================================================================================
# 四,评估模型
#================================================================================
printlog("step4: evaluating model ...")


#feature importance 
dfimportance = model.get_feature_importance(prettified=True) 
dfimportance = dfimportance.sort_values(by = "Importances").iloc[-20:]
fig_importance = px.bar(dfimportance,x="Importances",y="Feature Id",title="Feature Importance")

display(dfimportance)
display(fig_importance)


#score distribution
y_test_prob = model.predict_proba(dftest.drop(label_col,axis = 1))[:,-1]
fig_hist = px.histogram(
    x=y_test_prob,color =dftest[label_col],  nbins=50,
    title = "Score Distribution",
    labels=dict(color='True Labels', x='Score')
)
fig_hist.show() 



#================================================================================
# 五,使用模型
#================================================================================
printlog("step5: using model ...")

y_pred_test = model.predict(dftest)
y_pred_test_prob = model.predict_proba(dftest)

print("y_pred_test:\n",y_pred_test[:10])
print("y_pred_test_prob:\n",y_pred_test_prob[:10])


#================================================================================
# 六,保存模型
#================================================================================
printlog("step6: saving model ...")

model_dir = 'adult_model.cbm'
model.save_model(model_dir)
model_loaded = cb.CatBoostClassifier()
model.load_model(model_dir)

得到的adult_model.cbm放入到java项目的resource目录下.

三,Java端推理预测封装

推理代码封装如下

package com.example.model;

import ai.catboost.CatBoostModel;
import ai.catboost.CatBoostPredictions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Map;

public class CatBoostClassifier {

    public CatBoostModel model;

    private static Logger LOG= LoggerFactory.getLogger(CatBoostClassifier.class);

    public CatBoostClassifier(String model_path) {
        try{
            InputStream model_file = Thread.currentThread()
                    .getContextClassLoader()
                    .getResourceAsStream(model_path);
            this.model = CatBoostModel.loadModel(model_file);

        }catch (Exception err){
            LOG.error("catboost-init-error", err);
            System.out.println(err.toString());
        };

    }

    public Double predict(Map<String,Double> num_features,Map<String,String> cat_features){
        try {
            String[] feature_names = this.model.getFeatureNames();
            assert (num_features.size() + cat_features.size()) == feature_names.length;

            float[] num_arr = new float[num_features.size()];
            String[] cat_arr = new String[cat_features.size()];

            int i = 0;
            int j = 0;
            for (String name : feature_names) {
                if (num_features.keySet().contains(name)) {
                    num_arr[i] = num_features.get(name).floatValue();
                    i += 1;
                } else {
                    assert cat_features.keySet().contains(name);
                    cat_arr[j] = cat_features.get(name);
                    j += 1;
                }
            }
            CatBoostPredictions prediction = this.model.predict(
                    num_arr,
                    cat_arr);
            Double prob = 1.0 - 1.0 / (1.0 + Math.exp(prediction.get(0, 0)));
            return prob;
        }catch (Exception err){
            LOG.error("catboost-predict-error", err);
            return null;
        }
    }
}

四,预测调用测试代码

import java.util.*;
import com.example.model.CatBoostClassifier;

public class CatBoostTest {
    public void testCatBoostClassifier(){
        CatBoostClassifier clf = new CatBoostClassifier("adult_model.cbm");
        
        Map<String,Double> num_features = new HashMap<String,Double>();
        Map<String,String> cat_features = new HashMap<String,String>();
        
        num_features.put("fnlwgt",236379.0);
        num_features.put("education-num",11.0);
        
        
        cat_features.put("workclass","Private");
        cat_features.put("sex","Male");
        cat_features.put("relationship","Husband");
        cat_features.put("race","White");
        cat_features.put("occupation","Craft-repair");
        cat_features.put("native-country","United-States");
        cat_features.put("marital-status","Married-civ-spouse");
        cat_features.put("hoursperweek","2");
        cat_features.put("education","Assoc-voc");
        cat_features.put("capitalloss","0");
        cat_features.put("capitalgain","0");
        cat_features.put("age","1");

        Double model_prob = clf.predict(num_features,cat_features);
        System.out.println(model_prob);
        System.out.println("sucessed!");
    }

}

f2bb77ad8d50224cd048efc00515520a.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值