Deploying a DeepLabV3+ Trained ckpt Model on Android
This follows up on the previous article, "Training your own dataset with DeepLabV3+ and exporting a .pb model for visual testing".
I. Quantizing the ckpt model trained with DeepLabV3+
Refer to the official documentation and modify train.py:
flags.DEFINE_string('train_logdir', "./datasets/VOC/out",
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_float('base_learning_rate', 3e-5,
                   'The base learning rate for model training.')

flags.DEFINE_integer('training_number_of_steps', 5000,
                     'The number of steps used for training')  # number of training iterations

flags.DEFINE_string('tf_initial_checkpoint', "/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC/trainout/model.ckpt",
                    'The initial checkpoint in tensorflow format.')  # the model trained in the previous article

flags.DEFINE_integer(
    'quantize_delay_step', 0,
    'Steps to start quantized training. If < 0, will not quantize model.')
Run train.py. With quantize_delay_step set to 0, quantization-aware training starts from the very first step.
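For reference, this flag gates TensorFlow's quantization-aware training rewrite inside the official train.py; the relevant logic is roughly the following sketch (paraphrased, not a verbatim copy of the script):

from tensorflow.contrib import quantize as contrib_quantize

# Rewrite the training graph with fake-quantization ops so the model
# can later be exported as a fully quantized graph.
if FLAGS.quantize_delay_step >= 0:
    contrib_quantize.create_training_graph(
        quant_delay=FLAGS.quantize_delay_step)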
II. Converting the quantized .ckpt model to a .pb model
Modify export_model.py:
flags.DEFINE_string('checkpoint_path', "./datasets/VOC/out/model.ckpt", 'Checkpoint path')

flags.DEFINE_string('export_path', "./datasets/VOC/out/frozen_inference_graph.pb",
                    'Path to output Tensorflow frozen graph.')

flags.DEFINE_integer('num_classes', 2, 'Number of classes.')

flags.DEFINE_multi_integer('crop_size', [513, 513],
                           'Crop size [height, width].')

flags.DEFINE_integer(
    'quantize_delay_step', 0,
    'Steps to start quantized training. If < 0, will not quantize model.')
Run export_model.py.
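Before moving on to the conversion step, it helps to confirm the actual input/output tensor names in the frozen graph, since the backbone input name can vary with the model variant. A minimal sketch, assuming TF 1.x-style APIs via tf.compat.v1:

import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("./datasets/VOC/out/frozen_inference_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Print candidate input nodes and the ArgMax output node,
# which step III below needs by name.
for node in graph_def.node:
    if "input" in node.name.lower() or node.op == "ArgMax":
        print(node.op, node.name)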
III. Converting the .pb model to .tflite
Modify convert_to_tflite.py:
flags.DEFINE_string('quantized_graph_def_path', "./datasets/VOC/out/frozen_inference_graph.pb",
                    'Path to quantized graphdef.')

flags.DEFINE_string('output_tflite_path', "./datasets/VOC/out/frozen_inference_graph.tflite",
                    'Output TFlite model path.')

flags.DEFINE_string(
    'input_tensor_name', "MobilenetV2/MobilenetV2/input:0",
    'Input tensor to TFlite model. This usually should be the input tensor to '
    'model backbone.')

flags.DEFINE_string(
    'output_tensor_name', 'ArgMax:0',
    'Output tensor name of TFlite model. By default we output the raw semantic '
    'label predictions.')

flags.DEFINE_string(
    'test_image_path', "./datasets/VOC/JPEGImages/ADE_train_00000004.jpg",
    'Path to an image to test the consistency between input graphdef / '
    'converted tflite model.')
Run convert_to_tflite.py.
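To sanity-check the converted model, load it with the TFLite Python interpreter and inspect the input/output details. A minimal sketch; the expected shapes and dtypes in the comments follow from crop_size [513, 513] and the int64 ArgMax output, and they are exactly what the Android code in step IV has to match:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path="./datasets/VOC/out/frozen_inference_graph.tflite")
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
print("input :", inp["shape"], inp["dtype"])   # expected: [1 513 513 3], uint8
print("output:", out["shape"], out["dtype"])   # expected: [1 513 513], int64

# Run one dummy frame through the model.
dummy = np.zeros(inp["shape"], dtype=inp["dtype"])
interpreter.set_tensor(inp["index"], dummy)
interpreter.invoke()
mask = interpreter.get_tensor(out["index"])
print("labels present:", np.unique(mask))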
IV. Deploying the .tflite model
1. Download tensorflow-lite-examples-android and open it in Android Studio.
2. Copy the .tflite model file into the tensorflow-lite-examples-android-master/tensorflow-lite-examples-android-master/example-image-segmentation/src/main/assets folder.
3. In download.gradle, comment out the model-download task so the build does not overwrite the model you just copied into assets. The file should end up looking like this:
apply from:'../buildscripts/network.gradle'
//def targetModelFile = "src/main/assets/deeplabv3_257_mv_gpu.tflite"
//def modelDownloadUrl = "https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1?lite-format=tflite"
//
//// Required to download TF Lite model for Smart Reply.
//task downloadLiteModel(type: DownloadUrlTask) {
// sourceUrl = modelDownloadUrl
// target = file(targetModelFile)
//}
//preBuild.dependsOn downloadLiteModel
Modify ImageSegmentationModelExecutor:
// The ArgMax output is int64, so each output pixel takes 8 bytes.
segmentationMasks = ByteBuffer.allocateDirect(1 * imageSize * imageSize * 8)

private const val imageSegmentationModel = "frozen_inference_graph.tflite"
private const val imageSize = 513

const val NUM_CLASSES = 2
val segmentColors = IntArray(NUM_CLASSES)
val labelsArrays = arrayOf(
    "background", "sky"
)
private fun extraMaskAndItemsFromByteBuffer(buffer: ByteBuffer,
                                            width: Int,
                                            height: Int,
                                            colors: IntArray): Pair<Bitmap, Set<String>> {
    val start = System.currentTimeMillis()

    val pixels = width * height
    val data = IntArray(pixels)
    val itemsFound = HashSet<String>()

    // The model's ArgMax output is int64, so read the buffer as longs.
    val cacheInMemory = LongArray(pixels)
    buffer.rewind()
    buffer.asLongBuffer().get(cacheInMemory)

    for (y in 0 until height) {
        for (x in 0 until width) {
            // Map each pixel's class index to its display color and
            // record which labels appear in the frame.
            val c = cacheInMemory[y * width + x].toInt()
            data[y * width + x] = colors[c]
            itemsFound.add(labelsArrays[c])
        }
    }

    val end = System.currentTimeMillis()
    Logger.debug("convert $pixels pixels' buffer in ${end - start} ms")

    val maskBitmap = com.dailystudio.devbricksx.utils.ImageUtils.intArrayToBitmap(
        data, imageSize, imageSize)

    return Pair(maskBitmap, itemsFound)
}
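If you want to verify the same post-processing off-device first, this small Python sketch mirrors the Kotlin loop above, mapping each int64 class index to a color. The color table is a hypothetical example, index-aligned with labelsArrays:

import numpy as np
from PIL import Image

# Hypothetical colors for ("background", "sky"); pick your own.
SEGMENT_COLORS = np.array([[0, 0, 0], [70, 130, 180]], dtype=np.uint8)

def mask_to_color(mask):
    # mask: int64 label map of shape (H, W), as produced by ArgMax.
    return SEGMENT_COLORS[mask.astype(np.int32)]

# Using 'mask' from the interpreter check in step III:
# Image.fromarray(mask_to_color(mask[0])).save("mask_preview.png")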
Modify local.properties to point to your own SDK and NDK:
ndk.dir=/home/dreamdeck/Android/Sdk/ndk/22.0.7026061
sdk.dir=/home/dreamdeck/Android/Sdk
Modify ImageSegmentationCameraFragment:
const val MODEL_IMAGE_SIZE = 513
Modify TFLiteImageHelper.kt:
package com.dailystudio.tflite.example.common.image

import android.graphics.Bitmap
import java.nio.ByteBuffer
import java.nio.ByteOrder

object TFLiteImageHelper {

    fun bitmapToByteBuffer(
        bitmapIn: Bitmap,
        width: Int,
        height: Int,
        mean: Float = 0.0f,
        std: Float = 255.0f
    ): ByteBuffer {
        // The quantized model expects raw uint8 RGB input, so allocate
        // 1 byte per channel (width * height * 3) instead of 4 bytes
        // per float channel.
        val inputImage = ByteBuffer.allocateDirect(1 * width * height * 3)
        inputImage.order(ByteOrder.nativeOrder())
        inputImage.rewind()

        val intValues = IntArray(width * height)
        bitmapIn.getPixels(intValues, 0, width, 0, 0, width, height)

        var pixel = 0
        for (y in 0 until height) {
            for (x in 0 until width) {
                val value = intValues[pixel++]

                // A float model would normalize each channel instead, e.g. to
                // [-1.0, 1.0] (the exact range varies by model):
                // inputImage.putFloat(((value shr 16 and 0xFF) - mean) / std)
                // inputImage.putFloat(((value shr 8 and 0xFF) - mean) / std)
                // inputImage.putFloat(((value and 0xFF) - mean) / std)

                // Quantized model: write the raw R, G, B bytes directly.
                inputImage.put((value shr 16 and 0xFF).toByte())
                inputImage.put((value shr 8 and 0xFF).toByte())
                inputImage.put((value and 0xFF).toByte())
            }
        }

        inputImage.rewind()
        return inputImage
    }
}
4. Connect a phone, enable developer mode, and turn on USB debugging.
5. Run example-image-segmentation.