最近基于bi-lstm做了一个辱骂识别模型准备部署到线上,之前打算用python 启动一个service 通过http请求来调用,发现公司平台是基于rpc服务的,开发部署起来也较蛋疼,今天下午闲来没事,看到tensorflow中有提供官方例子,通过python中训练好模型,用java来调用,刚刚好摸索了下,动手写了下代码,总算能在java中调用,废话不多说,直接看代码实现情况。
tensorflow版本情况:
In [1]: import tensorflow as tf
In [2]: tf.__version__
Out[2]: '1.2.1'
java需要1.8的版本
maven依赖:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.2.1</version>
</dependency>
参考资料:
tensorflow训练模型时候要保存的模型参数,主要有是三个,一个是模型输入的tensor大小,一个是dropout参数,一个是模型预测的logits(score/pred_y 表示name_scope下的pred_y)值,也就是y;模型保存为一个二进制文件,可以在java中加载:
if i%500==0 and i>0:
graph = tf.graph_util.convert_variables_to_constants(session, session.graph_def,
["keep_prob", "input_x", "score/pred_y"])
tf.train.write_graph(graph, ".", "/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph/graph.db",
as_text=False)
java代码如下,其中gettexttoid方法参考tensorflow中 tensorflow.contrib.keras.preprocessing.sequence.pad_sequences下的实现,用于做文本预测:
package com.meituan.test;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
public class TensorflowEx {
private static String path = "/Users/shuubiasahi/Documents/python/credit-tftextclassify-abuse/vocab_cnews.txt";
private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();
static {
try {
BufferedReader buffer = null;
buffer = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
int i=0;
String line=buffer.readLine().trim();
while(line!=null){
word_to_id.put(line, i++);
line=buffer.readLine().trim();
}
buffer.close();
} catch (Exception e) {
}
System.out.println("word_to_id.size is:"+word_to_id.size());
}
public static void main(String[] args) {
byte[] graphDef = readAllBytesOrExit(Paths.get(
"/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph",
"graph.db"));
Graph g = new Graph();
g.importGraphDef(graphDef);
Session sess = new Session(g);
String text="艹你麻痹的垃圾店家,劳资点的香干回锅肉套餐,你他麻痹炒个香干炒肉过来凑数,套餐内所有的东西都没看到,还尼玛口口声声说退款?退你麻痹,留着给你家人买棺材用吧,狗日的东西!";
int[][] arr=gettexttoid(text);
Tensor input = Tensor.create(arr);
Tensor x = Tensor.create(1.0f);
Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", x)
.fetch("score/pred_y").run().get(0);
long[] rshape = result.shape();
/*
* 模型为一个二分类模型,故nlabels=2,模型预测一条数据故batchsize=1
* 预测出来是一个1*2的数组,一条数据有两个概率
*
* */
int nlabels = (int) rshape[1];
int batchSize = (int) rshape[0];
float[][] logits = result.copyTo(new float[batchSize][nlabels]);
System.out.println("辱骂模型识别的概率为:"+logits[0][1]);
System.out.println("sucess");
}
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: "
+ e.getMessage());
System.exit(1);
}
return null;
}
/*
* 序列默人长度为300
* */
public static int[][] gettexttoid(String text){
int[][] xpad = new int[1][300];
if(StringUtils.isBlank(text)){
return xpad;
}
char[] chs=text.trim().toLowerCase().toCharArray();
List<Integer> list=new ArrayList<Integer>();
for(int i=0;i<chs.length;i++){
String element=Character.toString(chs[i]);
if(word_to_id.containsKey(element)){
list.add(word_to_id.get(element));
}
}
if(list.size()==0){
return xpad;
}
int size = list.size();
Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);
int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
if(size<=300){
System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);
}else{
System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
/*
* 自定义长度
* */
public static int[][] gettexttoid(String text,int maxlen){
if(maxlen<1){
throw new IllegalArgumentException("maxlen长度必须大于等于1");
}
int[][] xpad = new int[1][maxlen];
if(StringUtils.isBlank(text)){
return xpad;
}
char[] chs=text.trim().toLowerCase().toCharArray();
List<Integer> list=new ArrayList<Integer>();
for(int i=0;i<chs.length;i++){
String element=Character.toString(chs[i]);
if(word_to_id.containsKey(element)){
list.add(word_to_id.get(element));
}
}
if(list.size()==0){
return xpad;
}
int size = list.size();
Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);
int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
if(size<=maxlen){
System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);
}else{
System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
}
结果对比:
java结果:

python启动的service结果:

结果一致,下周计划写个java service项目,把模型部署上线。
不过我碰到过问题,在java中做预测,1秒最多只能预测十来条文本,这感觉太慢了,不知道什么原因,我机器用的cpu,不知道是否要用gpu做预测,有知道的告诉我
联系我 xuxu_ge