public static String getParamFromModelFile(File modelFile, String jarName) {
if (modelFile == null || !modelFile.exists()) {
logger.error(“解析AI模型的参数,jar文件不存在”);
return null;
}
if (StringUtil.isBlank(jarName)) {
logger.error(String.format(“解析ai模型参数模型文件所在路径或模型名称为空, 路径: %s, 名称: %s”, modelFile.getAbsolutePath(), jarName));
return null;
}
if (jarName.endsWith(".jar")) { // 解析jar模型的参数
String[] params = getParamArrayFromJar(modelFile, jarName);
if (null == params || params.length <= 0) {
logger.error(“解析AI模型参数为空”);
return null;
}
//示例: {“Species”:“string”,“sepallength”:“string”}
JSONObject obj = new JSONObject();
for (String param : params) {
obj.put(param, “string”);
}
return obj.toJSONString();
} else if (jarName.endsWith(".pmml") || jarName.endsWith(".xml")){ //解析pmml 和 xml 模型的参数
return readXmlOrPmmlParams(modelFile);
} else {
logger.error(“解析模型参数,模型文件格式错误”);
return null;
}
}
/**
 * Reads the input fields declared in the {@code <DataDictionary>} element of a PMML/XML model.
 *
 * <p>Only the first {@code DataDictionary} element that is a direct child of the document root
 * is used. Each of its child elements carrying both a {@code name} and a {@code dataType}
 * attribute contributes one {@code "name":"dataType"} entry, in document order, e.g.
 * {@code {"sepal_length":"double","species":"string"}}. Children lacking either attribute
 * (such as PMML {@code <Extension>} elements) are skipped. Names and types are JSON-escaped.
 *
 * <p>External general/parameter entities and external DTD loading are disabled, because
 * model files may be user-supplied (XXE protection).
 *
 * @author mahongfei
 * @param file the PMML/XML file to parse
 * @return a JSON object string, or {@code ""} when there is no {@code DataDictionary} or it
 *         declares no fields
 * @throws IllegalStateException if the file cannot be read or parsed (cause attached)
 */
public static String readXmlOrPmmlParams(File file) {
    try {
        SAXReader reader = new SAXReader();
        reader.setFeature("http://xml.org/sax/features/external-general-entities", false);
        reader.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        reader.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
        Document document = reader.read(file);
        Element dataDictionary = document.getRootElement().element("DataDictionary");
        if (dataDictionary == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        List<Element> fields = dataDictionary.elements();
        for (Element field : fields) {
            String name = field.attributeValue("name");
            String dataType = field.attributeValue("dataType");
            if (name == null || dataType == null) {
                // Not a field declaration; skip it instead of failing on the missing attribute.
                continue;
            }
            sb.append(sb.length() == 0 ? '{' : ',')
                    .append('"').append(escapeJsonText(name)).append("\":\"")
                    .append(escapeJsonText(dataType)).append('"');
        }
        if (sb.length() > 0) {
            sb.append('}');
        }
        return sb.toString();
    } catch (Exception e) {
        logger.error(e.getLocalizedMessage(), e);
        throw new IllegalStateException("读取参数文件失败", e);
    }
}

/**
 * Escapes {@code text} so it can be embedded between double quotes in a JSON string literal:
 * quote, backslash and all control characters below U+0020 are escaped.
 */
private static String escapeJsonText(String text) {
    StringBuilder out = new StringBuilder(text.length());
    for (int i = 0; i < text.length(); i++) {
        char c = text.charAt(i);
        switch (c) {
            case '"':
                out.append("\\\"");
                break;
            case '\\':
                out.append("\\\\");
                break;
            case '\n':
                out.append("\\n");
                break;
            case '\r':
                out.append("\\r");
                break;
            case '\t':
                out.append("\\t");
                break;
            default:
                if (c < 0x20) {
                    out.append(String.format("\\u%04x", (int) c));
                } else {
                    out.append(c);
                }
        }
    }
    return out.toString();
}
/**
 * Reads the parameter names stored inside a Maxim AI platform jar model.
 *
 * <p>The jar keeps its parameter list in a dedicated holder class named
 * {@code paramClassPrefix + jarName}, as a public static {@code String[]} constant whose name
 * is {@code paramClassfield}. These are documented as {@code "NamesHolder_" + jarName} and
 * {@code VALUES}; confirm the constant values.
 *
 * <p>The jar is opened in its own short-lived {@link URLClassLoader} (parent: the system class
 * loader), which is closed before returning, instead of being appended to the system class
 * loader via reflection. This works on Java 9+ (where the system loader is not a
 * {@code URLClassLoader}), releases the jar file handle, and means a re-uploaded jar with the
 * same name is read afresh rather than served from a previously loaded class.
 *
 * <p><b>Security note:</b> reading the constant initializes the holder class, so its static
 * initializer (code from the jar) runs in this JVM. Only pass trusted model files.
 *
 * @param jarFile the model file ( .jar ); its file name must end with {@code .jar}
 * @param jarName the jar name, with or without the {@code .jar} suffix
 * @return the parameter names, or {@code null} if the arguments are invalid or reading fails
 *         (failures are logged)
 */
public static String[] getParamArrayFromJar(File jarFile, String jarName) {
    if (jarFile == null || jarName == null) {
        logger.error("解析AI模型参数出错,模型文件或模型名称为空");
        return null;
    }
    if (!jarFile.getName().endsWith(".jar")) {
        logger.error("文件格式错误,当前文件:" + jarFile.getName());
        return null;
    }
    if (jarName.endsWith(".jar")) {
        jarName = jarName.substring(0, jarName.length() - 4);
        logger.info("解析AI模型参数, 模型名称以.jar结尾,去除后缀后: " + jarName);
    }
    String className = paramClassPrefix + jarName;
    try (URLClassLoader classLoader = new URLClassLoader(
            new URL[] {jarFile.toURI().toURL()}, ClassLoader.getSystemClassLoader())) {
        Class<?> clazz = classLoader.loadClass(className);
        Field paramField = clazz.getField(paramClassfield);
        // The holder class itself may be non-public.
        paramField.setAccessible(true);
        // Static field: the receiver is ignored. This triggers class initialization.
        Object value = paramField.get(null);
        if (value == null) {
            throw new IllegalArgumentException("解析AI模型参数出错,获得字段的值为空");
        }
        if (!(value instanceof String[])) {
            throw new IllegalArgumentException(
                    "解析AI模型参数出错,字段类型不是String[]: " + value.getClass().getName());
        }
        // Read fully before the loader is closed by try-with-resources.
        return (String[]) value;
    } catch (ReflectiveOperationException | java.io.IOException e) {
        // Same message per exception type as before (e.g. ClassNotFoundException), now with the cause.
        logger.error(String.format("解析Ai模型参数出错,%s,模型路径: %s, 模型名称: %s",
                e.getClass().getSimpleName(), jarFile.getAbsolutePath(), jarName), e);
    } catch (Exception | LinkageError e) {
        // LinkageError covers e.g. ExceptionInInitializerError from the jar's static initializer.
        logger.error("解析AI模型出错: " + e.getLocalizedMessage(), e);
    }
    return null;
}
/**
 * Ad-hoc demo entry point.
 *
 * <p>Builds two {@code ModelResult} samples with a random label and a random score, then prints
 * the ROC coordinates, KS and AUC computed from them. Finally, prints the parameter JSON
 * parsed from a hard-coded local jar model.
 */
public static void main(String[] args) {
    List<ModelResult> samples = new ArrayList<>(100);
    for (int sample = 0; sample < 2; sample++) {
        ModelResult r = new ModelResult();
        r.setKey1("p1");
        r.setKey2("p0");
        r.setTpLable("1");
        r.setFnLable("0");
        // Uniform integer in [1, 10]; even -> label "0", odd -> label "1".
        int dice = (int) (1 + Math.random() * 10);
        r.setLabel(dice % 2 == 0 ? "0" : "1");
        // Score uniform in [0, 1).
        r.setScore(new Random().nextDouble());
        samples.add(r);
    }

    List<Roc> roc = getRocCoordinates(samples);
    System.out.println(roc);
    System.out.println("ks = " + getKs(roc));
    System.out.println("auc = " + getAuc(roc));

    System.out.println("解析jar参数字段: ");
    File demoJar = new File("D://XRT_0_AutoML_20190808_101915.jar");
    String parsed = getParamFromModelFile(demoJar, "XRT_0_AutoML_20190808_101915.jar");
    System.out.println(parsed);
}