之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。
思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。
工作思路:1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;
但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。
1.网络结构
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def ttnet(images, num_classes=10, is_training=False,
dropout_keep_prob=0.5,
prediction_fn=slim.softmax,
scope='TtNet'):
end_points = {}
with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
net = slim.conv2d(images, 32, [3, 3], scope='conv1')
# net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
# net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
net = slim.conv2d(net, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
net = slim.conv2d(net, 128, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
net = slim.conv2d(net, 256, [3, 3], scope='conv4')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
# net = slim.conv2d(net, 512, [3, 3], scope='conv5