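Running inference for a CatBoost model in Java is much simpler than for LightGBM: there is no need to convert the model to PMML, because the official Java package can be used directly.
Best of all, it accepts string-typed categorical features as they are, so no encoding step is needed, which is about as convenient as it gets.
Reference documentation: https://catboost.ai/en/docs/concepts/java-package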

1. Add the Maven dependency to the Java project
Note: the version should match the catboost version used on the Python side.
<!-- https://mvnrepository.com/artifact/ai.catboost/catboost-prediction -->
<dependency>
    <groupId>ai.catboost</groupId>
    <artifactId>catboost-prediction</artifactId>
    <version>1.0.6</version>
</dependency>
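Because the `<version>` in the dependency is meant to match the Python package, it can help to look up the installed catboost version before editing the `pom.xml`. The snippet below is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical and not part of catboost.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Copy this value into the <version> element of the Maven dependency.
    print("catboost version:", installed_version("catboost"))
```

If this prints `None`, catboost is not installed in the current Python environment.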
2. Train the CatBoost model in Python
The example here is the binary classification task on the adult dataset.
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import datetime
import numpy as np
import pandas as pd
import plotly.express as px
import catboost as cb
def printlog(info):
    # Print a timestamped separator line followed by the message
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)
    print(info+'...\n\n')
#================================================================================
# 1. Prepare the data
#================================================================================
printlog("step1: preparing data...")
adult = fetch_openml(name = "adult",version=1,as_frame=True)
label_col = "target"
dfdata = adult["data"].copy()  # work on a copy so adding the label column does not touch the fetched frame
dfdata[label_col] = (adult["target"]==">50K").astype("float")
cat_features = (dfdata.drop("target",axis=1)
                .select_dtypes("category").columns.values.tolist())
cat_features.sort()
num_features = [col for col in dfdata.columns if col not in cat_features+[label_col]]
num_features.sort()
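# Fill missing categories before casting to str: casting first would typically turn NaN
# into the literal string "nan", and a fillna() whose result is not assigned has no effect.
dfdata[cat_features] = dfdata[cat_features].astype(object).fillna("missed").astype(str)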
dfdata = dfdata[num_features+cat_features+[label_col]]
dftrain_val,dftest = train_test_split(dfdata,test_size=0.3)
dftrain,dfval = train_test_split(dftrain_val,test_size=0.3)
# Wrap the splits in CatBoost Pool objects
pool_train = cb.Pool(data = dftrain.drop(label_col,axis=1),
                     label = dftrain[label_col], cat_features=cat_features)
pool_val = cb.Pool(data = dfval.drop(label_col,axis=1),
                   label = dfval[label_col], cat_features=cat_features)

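This article shows how to run inference with a CatBoost model in Java without converting it to PMML. Using the adult dataset as an example, it walks through the full workflow from training the model to predicting on the Java side.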
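The training script sorts the numeric and categorical column names and fills missing categories with "missed". As far as the referenced Java package documentation describes, prediction takes numeric features as a float array and categorical features as a string array, so the serving side has to reproduce the same ordering. The sketch below illustrates that split in plain Python; `split_row` and the sample row are hypothetical and only meant to show the idea, not the article's own code.

```python
import math


def split_row(row, num_features, cat_features, missing_token="missed"):
    """Split one record into (numeric values, categorical values).

    Both lists follow the sorted column-name order used by the training script.
    Missing numeric values become NaN; missing categories become `missing_token`.
    """
    numeric = []
    for name in sorted(num_features):
        value = row.get(name)
        numeric.append(math.nan if value is None else float(value))

    categorical = []
    for name in sorted(cat_features):
        value = row.get(name)
        categorical.append(missing_token if value is None else str(value))

    return numeric, categorical


if __name__ == "__main__":
    # A hypothetical record with a few adult-style columns, chosen for illustration.
    sample = {"age": 39, "hours-per-week": 40, "workclass": None, "sex": "Male"}
    nums, cats = split_row(sample, ["hours-per-week", "age"], ["workclass", "sex"])
    print(nums, cats)  # [39.0, 40.0] ['Male', 'missed']
```

Keeping this ordering logic in one place makes it easier to check that the arrays built on the Java side line up with the columns the model was trained on.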