deeplabv3+训练的ckpt模型Android端部署

本文详细指导如何使用 DeeplabV3+ 的训练模型进行量化处理,包括修改train.py和export_model.py脚本,然后转换为pb模型,进一步转化为TFLite模型,并在Android端进行部署。步骤涉及模型量化、模型导出和在实际应用中的适配。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

deeplabv3+训练的ckpt模型Android端部署

接上一篇文章用deeplabV3+训练自己的数据集并导出.pb模型进行可视化测试

一.对deeplabv3+训练的ckpt模型做量化

参考官方文档
修改train.py文件

flags.DEFINE_string('train_logdir', "./datasets/VOC/out",
                    'Where the checkpoint and logs are stored.')
flags.DEFINE_float('base_learning_rate', 3e-5,
                   'The base learning rate for model training.')
flags.DEFINE_integer('training_number_of_steps', 5000,
                     'The number of steps used for training')  ###########  迭代次数
flags.DEFINE_string('tf_initial_checkpoint', "/home/dreamdeck/Downloads/Tensorflow/models-master/research/deeplab/datasets/VOC/trainout/model.ckpt",
                   'The initial checkpoint in tensorflow format.') ##上一步训练好的模型文件
flags.DEFINE_integer(
    'quantize_delay_step', 0,
    'Steps to start quantized training. If < 0, will not quantize model.')

运行train.py

二.量化后的.ckpt模型转.pb模型

修改export_model.py

flags.DEFINE_string('checkpoint_path', "./datasets/VOC/out/model.ckpt", 'Checkpoint path')

flags.DEFINE_string('export_path', "./datasets/VOC/out/frozen_inference_graph.pb",
                    'Path to output Tensorflow frozen graph.')

flags.DEFINE_integer('num_classes', 2, 'Number of classes.')

flags.DEFINE_multi_integer('crop_size', [513, 513],
                           'Crop size [height, width].')
flags.DEFINE_integer(
    'quantize_delay_step', 0,
    'Steps to start quantized training. If < 0, will not quantize model.')

运行export_model.py

三. pb模型转tflite

修改convert_to_tflite.py文件

flags.DEFINE_string('quantized_graph_def_path', "./datasets/VOC/out/frozen_inference_graph.pb",
                    'Path to quantized graphdef.')
flags.DEFINE_string('output_tflite_path', "./datasets/VOC/out/frozen_inference_graph.tflite", 'Output TFlite model path.')
flags.DEFINE_string(
    'input_tensor_name', "MobilenetV2/MobilenetV2/input:0",
    'Input tensor to TFlite model. This usually should be the input tensor to '
    'model backbone.'
)
flags.DEFINE_string(
    'output_tensor_name', 'ArgMax:0',
    'Output tensor name of TFlite model. By default we output the raw semantic '
    'label predictions.'
)
flags.DEFINE_string(
    'test_image_path', "./datasets/VOC/JPEGImages/ADE_train_00000004.jpg",
    'Path to an image to test the consistency between input graphdef / '
    'converted tflite model.'
)

运行convert_to_tflite.py

四.tflite模型部署

1.下载tensorflow-lite-examples-android并用android studio软件打开
2.将tflite模型文件copy到tensorflow-lite-examples-android-master/tensorflow-lite-examples-android-master/example-image-segmentation/src/main/assets文件夹中
3.注释掉download.gradle文件中的

apply from:'../buildscripts/network.gradle'

//def targetModelFile = "src/main/assets/deeplabv3_257_mv_gpu.tflite"
//def modelDownloadUrl = "https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1?lite-format=tflite"
//
//// Required to download TF Lite model for Smart Reply.
//task downloadLiteModel(type: DownloadUrlTask) {
//    sourceUrl = modelDownloadUrl
//    target = file(targetModelFile)
//}

//preBuild.dependsOn downloadLiteModel

修改ImageSegmentationModelExecutor文件

segmentationMasks = ByteBuffer.allocateDirect(1 * imageSize * imageSize * 8)
    private const val imageSegmentationModel = "frozen_inference_graph.tflite"   
    private const val imageSize = 513
    const val NUM_CLASSES = 2
    
    val segmentColors = IntArray(NUM_CLASSES)
    val labelsArrays = arrayOf(
      "background", "sky"
    )
  private fun extraMaskAndItemsFromByteBuffer(buffer: ByteBuffer,
                                              width: Int,
                                              height: Int,
                                              colors: IntArray): Pair<Bitmap, Set<String>> {
    val start = System.currentTimeMillis()

    val pixels = width * height
    val data = IntArray(pixels)
    val itemsFound = HashSet<String>()
    val cacheInMemory = LongArray(width * height)

    buffer.rewind()
    buffer.asLongBuffer().get(cacheInMemory)

    for (y in 0 until height) {
      for (x in 0 until width) {
        var itemIndex = 0

        var c = cacheInMemory[y * width + x].toInt()
        data[y * width + x] = colors[c]
        itemIndex = c


        itemsFound.add(labelsArrays[itemIndex])
      }
    }

    val end = System.currentTimeMillis()

    Logger.debug("convert $pixels pixels' buffer in ${end - start} ms")

    val maskBitmap = com.dailystudio.devbricksx.utils.ImageUtils.intArrayToBitmap(
      data, imageSize, imageSize)

    return Pair(maskBitmap, itemsFound)
  }

修改local.properties文件为自己对应的sdk和ndk

ndk.dir=/home/dreamdeck/Android/Sdk/ndk/22.0.7026061
sdk.dir=/home/dreamdeck/Android/Sdk

修改ImageSegmentationCameraFragment文件

const val MODEL_IMAGE_SIZE = 513

修改TFLiteImageHelper.kt文件

package com.dailystudio.tflite.example.common.image

import android.graphics.Bitmap
import android.util.Log
import java.nio.ByteBuffer
import java.nio.ByteOrder

object TFLiteImageHelper {

    fun bitmapToByteBuffer(
        bitmapIn: Bitmap,
        width: Int,
        height: Int,
        mean: Float = 0.0f,
        std: Float = 255.0f
    ): ByteBuffer {
        val inputImage = ByteBuffer.allocateDirect(1 * width * height * 3)
        inputImage.order(ByteOrder.nativeOrder())
        inputImage.rewind()

        val intValues = IntArray(width * height)
        bitmapIn.getPixels(intValues, 0, width, 0, 0, width, height)
        var pixel = 0
        for (y in 0 until height) {
            for (x in 0 until width) {
                val value = intValues[pixel++]

                // Normalize channel values to [-1.0, 1.0]. This requirement varies by
                // model. For example, some models might require values to be normalized
                // to the range [0.0, 1.0] instead.
//                val a = value shr 16 and 0xFF
//                val b = value shr 8 and 0xFF
//                val c = value and 0xFF
//                inputImage.putFloat(((value shr 16 and 0xFF) - mean) / std)
//                inputImage.putFloat(((value shr 8 and 0xFF) - mean) / std)
//                inputImage.putFloat(((value and 0xFF) - mean) / std)
//                Log.d("TFLiteImageHelper", "y= $y, x= $x, a= $a, b= $b, c= $c")
                inputImage.put(((value shr 16 and 0xFF).toByte()))
                inputImage.put(((value shr 8 and 0xFF).toByte()))
                inputImage.put(((value and 0xFF).toByte()))
            }
        }

        inputImage.rewind()
        return inputImage
    }

}

4.连接手机并打开手机的开发者模式,打开usb调试.
5.运行example-image-segmentation

<think>好的,我现在需要帮助用户解决如何在DeepLabV3+模型训练中断后继续训练的问题。用户提到了使用TensorFlow或PyTorch,所以我应该分别考虑这两种框架下的解决方案。 首先,我需要回忆一下断点续训的一般步骤。无论是TensorFlow还是PyTorch,核心都是保存和加载模型的状态,包括模型参数、优化器状态、以及当前的epoch和损失值等信息。这样在恢复训练时,可以从中断的地方继续,而不用从头开始。 对于TensorFlow,我记得可以使用回调函数ModelCheckpoint来定期保存模型。用户需要在训练时设置这个回调,这样在中断后,可以加载最新的检查点。然后使用model.fit的initial_epoch参数来指定从哪个epoch开始继续训练。需要注意的是,TensorFlow的检查点可能只保存模型权重,所以优化器的状态可能需要额外处理,特别是在使用自定义优化器或学习率调度器时。 然后是PyTorch部分。PyTorch通常通过torch.save来保存整个模型的状态字典,包括模型、优化器、epoch等信息。恢复训练时,需要先加载这些状态,然后重新初始化模型和优化器,并将状态加载进去。之后在训练循环中,从保存的epoch开始继续训练。这里要确保数据加载器的状态也能正确恢复,可能需要保存随机数生成器的状态,否则数据顺序可能不同,影响训练结果。 接下来,我需要考虑用户可能遇到的问题。例如,检查点没有保存优化器状态,或者保存的格式不正确。这时候可能需要检查保存和加载的代码是否正确,确保所有必要的部分都被包含。另外,如果用户使用的是分布式训练,情况会更复杂,可能需要额外的处理步骤。 然后,我需要验证一下步骤是否正确。比如,在TensorFlow中,是否真的需要设置initial_epoch,或者是否有其他参数需要注意。在PyTorch中,加载模型后是否必须调用model.train()来确保模型处于训练模式。同时,提醒用户注意学习率调度器的状态是否被正确保存和恢复,因为有些调度器是根据epoch来调整学习率的,如果epoch没有恢复,可能会导致学习率不正确。 用户可能还想知道如何自动保存检查点,比如每隔几个epoch保存一次,或者只在验证损失改善时保存。这时候需要指导他们在回调函数或训练循环中添加相应的逻辑。 最后,生成相关问题时要覆盖常见的后续问题,比如如何调整学习率、处理不同数据集、多GPU训练下的断点续训,以及检查点无法加载时的解决方法。这些问题可以帮助用户深入理解断点续训的各个方面,并解决可能遇到的困难。</think>### 如何在DeepLabV3+训练被中断后继续训练 #### 一、核心原理 断点续训的核心是通过保存和加载训练状态实现连续性。需要保存以下要素: 1. 模型权重:记录网络参数$W_t$和偏置$b_t$ 2. 优化器状态:动量参数$v_t$、梯度二阶矩估计$m_t$等(如Adam优化器) 3. 训练元数据:当前epoch数$n$、学习率$\eta$、损失值$L$等 #### 二、TensorFlow实现方法 ```python # 保存检查点 checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) manager = tf.train.CheckpointManager(checkpoint, directory=&#39;./ckpt&#39;, max_to_keep=3) # 训练循环中 for epoch in range(initial_epoch, total_epochs): # ...训练步骤... manager.save() # 恢复训练 latest_ckpt = tf.train.latest_checkpoint(&#39;./ckpt&#39;) checkpoint.restore(latest_ckpt) initial_epoch = int(latest_ckpt.split(&#39;-&#39;)[1]) # 解析已完成的epoch数 ``` #### 三、PyTorch实现方法 ```python # 保存状态 torch.save({ &#39;epoch&#39;: epoch, &#39;model_state_dict&#39;: model.state_dict(), &#39;optimizer_state_dict&#39;: optimizer.state_dict(), &#39;loss&#39;: loss, }, &#39;checkpoint.pth&#39;) # 加载状态 checkpoint = torch.load(&#39;checkpoint.pth&#39;) model.load_state_dict(checkpoint[&#39;model_state_dict&#39;]) optimizer.load_state_dict(checkpoint[&#39;optimizer_state_dict&#39;]) start_epoch = checkpoint[&#39;epoch&#39;] + 1 ``` #### 四、注意事项 1. 数据加载器状态恢复:需保存随机数生成器状态`torch_rng_state`和`numpy_rng_state` 2. 学习率调度器状态:需保存`lr_scheduler.state_dict()` 3. 混合精度训练:需额外保存梯度缩放器状态`scaler.state_dict()` 4. 分布式训练:需使用`module.state_dict()`获取完整模型参数[^1]
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值