tensorflow模型部署与python java API线上调用

项目github地址:bitcarmanlee easy-algorithm-interview-and-practice
欢迎大家star,留言,一起学习进步

0.前言

随着深度模型的普及,线上越来越多的模型换成了深度模型,与此对应的线上模型的部署与调用方式也会发生变化。下面我们就来介绍一下分别用python代码与java代码调用训练好的模型。

1.模型训练

首先我们训练一个简单的模型 y = 3 x + 0.1 y = 3x + 0.1 y=3x+0.1

#!/usr/bin/env python
# encoding: utf-8

"""
@author: wanglei
@time: 2020/7/30 上午9:25
"""

import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util

train_X = np.linspace(-1, 1, 100)
train_Y = 3*train_X + 0.1


X = tf.placeholder("float",name='X')
Y = tf.placeholder("float",name='Y')
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")

z = tf.multiply(X, W) + b
op = tf.add(tf.multiply(X, W),b,name='results')

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

# 定义参数
training_epochs = 20
display_step = 2

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    plotdata = {'batchsize': [], 'loss': []}

    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X:x, Y:y})

        if epoch % display_step == 0:
            loss = sess.run(cost, {X: train_X, Y: train_Y})
            print("Epoch:", epoch+1, "cost=", loss, "W=", sess.run(W), "b=",sess.run(b))

    saver.save(sess, 'mymodels/first')
    print("cost =", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
    const_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["results"])

    with tf.gfile.FastGFile("mymodels/first.pb", mode='wb') as f:
        f.write(const_graph.SerializeToString())

print("Finished!")

上面的代码,会最终得到一个first.pd的模型。训练输出的结果如下

Epoch: 1 cost= 1.2481742 W= [1.2849432] b= [0.59788895]
Epoch: 3 cost= 0.10257308 W= [2.539106] b= [0.27416894]
Epoch: 5 cost= 0.006951133 W= [2.8805332] b= [0.1457994]
Epoch: 7 cost= 0.00046514577 W= [2.9691048] b= [0.11185519]
Epoch: 9 cost= 3.1108106e-05 W= [2.9920104] b= [0.10306594]
Epoch: 11 cost= 2.0804118e-06 W= [2.9979339] b= [0.10079286]
Epoch: 13 cost= 1.3938015e-07 W= [2.9994652] b= [0.10020526]
Epoch: 15 cost= 9.405541e-09 W= [2.999861] b= [0.10005322]
Epoch: 17 cost= 6.447206e-10 W= [2.9999638] b= [0.10001406]
Epoch: 19 cost= 5.3264854e-11 W= [2.9999897] b= [0.10000414]
cost = 2.6606464e-11 W= [2.9999924] b= [0.10000262]

最终得到的参数为 W = 2.9999924 , b = 0.10000262 W=2.9999924, b=0.10000262 W=2.9999924,b=0.10000262,与一次函数 y = 3 x + 0.1 y=3x+0.1 y=3x+0.1一致。

2.python代码调用模型

用python代码调用上面训练好的模型,示例如下

#!/usr/bin/env python
# encoding: utf-8

"""
@author: wanglei
@time: 2020/7/30 下午2:39
"""

import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()

with gfile.FastGFile('mymodels/first.pb', 'rb') as f:
    graph = tf.GraphDef()
    graph.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph, name='')

    sess.run(tf.global_variables_initializer())
    print(sess.run('weight:0'))
    print(sess.run('bias:0'))

    input_x = sess.graph.get_tensor_by_name('X:0')
    op = sess.graph.get_tensor_by_name('results:0')
    ret = sess.run(op, feed_dict={input_x: 2})
    print("ret is: ", ret)

运行上面的代码,输出如下

[2.9999924]
[0.10000262]
ret is:  [6.0999875]

3.java API调用

使用java解析调用模型的时候,需要加入tensorflow的官方依赖

    <properties>
        <grpc.version>1.20.0</grpc.version>
        <tensorflow.version>1.13.1</tensorflow.version>
    </properties>

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>${tensorflow.version}</version>
        </dependency>

因为我们代码里会使用到common-io里面的代码,所以加入相应的依赖

        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.4</version>
        </dependency>

然后开始解析模型并打分

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;

/**
 * author: wanglei
 * create: 2020-07-30
 */
public class Demo1 {

    public static void test() throws Exception {
        Graph graph = new Graph();
        String filepath = "xxx/mymodels/first.pb";
        // 模型的图结构
        byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(filepath));
        graph.importGraphDef(graphBytes);

		// 根据图结构建议sess
        Session sess = new Session(graph);
        Tensor X = Tensor.create(2.0f);
        Tensor result = sess.runner().feed("X", X).fetch("results").run().get(0);
        System.out.println(result);
        float[] vector = new float[1];
        result.copyTo(vector);
        System.out.println("vector[0] is: " + vector[0]);
    }

    public static void main(String[] args) throws Exception {
        test();
    }
}

代码的输出为

FLOAT tensor with shape [1]
vector[0] is: 6.0999875
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值