pytorch版本Deeplabv3+网络模型格式转换(pth转pt)

为了进一步使用c++调用deeplabv3+模型,使用trace将pytorch训练生成的.pth格式转为.pt

参考:https://github.com/shanson123/ORB_SLAM2_DeeplabV3/blob/master/DeeplabV3/create_deeplabv3.py

在predict.py文件中添加:

    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0) # To tensor of NCHW
            img = img.to(device)
            
            pred = model(img).max(1)[1].cpu().numpy()[0] # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png'))

        #pth转pt
        traced_model = torch.jit.trace(model.module, img.to(device))
        traced_model.save("DeeplabV3plus.pt")

注意,如果写成 traced_model = torch.jit.trace(model, img.to(device)),会出现下图的报错:
Could not export Python function call ‘Scatter’.
请添加图片描述

### DeepLab 图像分割模型使用教程 #### 1. 模型概述 DeepLab 是一种用于语义图像分割的强大工具,通过引入空洞卷积(Atrous Convolution)、多尺度上下文聚合以及空间金字塔池化模块来提升性能。该系列经历了多个版本的发展,从 Deeplab v1 到最新的 Deeplab v3+ 不断改进结构设计以提高精度和效率[^2]。 #### 2. 安装环境准备 为了能够顺利运行并训练 DeepLab 模型,在本地计算机上需安装必要的依赖库。推荐采用 Anaconda 创建虚拟环境,并按照官方文档中的说明完成 TensorFlow 或 PyTorch 的安装配置工作。对于具体操作步骤可以参考 GitHub 上的相关项目页面提供的指导信息[^1]。 #### 3. 数据集获取与预处理 针对特定应用场景选择合适的数据集非常重要。常用公开数据集如 PASCAL VOC、COCO 等可用于初步实验测试;而对于定制化的业务需求,则可能需要自行收集标注图片作为输入素材。在实际应用前还需对原始图像做标准化变换、裁剪缩放等一系列预处理措施以便更好地适应算法框架的要求[^4]。 #### 4. 训练过程详解 利用给定的开源代码仓库 `train-DeepLab` 可快速搭建起基于 Python 和 TensorFlow 实现的 DeepLabV3+/MobileNet 版本实例程序。用户可以根据自身硬件条件调整超参数设置(比如批次大小 Batch Size, 学习率 Learning Rate 等),并通过 TensorBoard 工具实时监控损失函数变化趋势从而优化调参策略。 ```bash # 下载并解压Pascal VOC数据集 wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar tar xf VOCtrainval_11-May-2012.tar # 进入项目目录执行脚本启动训练流程 cd /path/to/train-deeplab/ python deeplab_train.py --dataset_dir=./datasets/pascal_voc_seg \ --model_variant="mobilenet_v2" \ --atrous_rates=6 \ --output_stride=16 \ --train_crop_size="513,513" ``` #### 5. 部署到移动设备 当完成了模型训练之后,如果希望将其应用于移动端应用程序开发当中去的话,那么就需要考虑如何有效地压缩量化权重文件使之能够在资源受限环境下正常运作。这里可以通过 ONNX Runtime 来换导出经过优化后的 .pt 格式的深度神经网络描述符至 Android 平台之上进行推理计算。 ```python import torch.onnx as onnx from torchvision import models dummy_input = torch.randn(1, 3, 224, 224).cuda() model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval().cuda() onnx.export(model, dummy_input, "deeplabv3.onnx", export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output']) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值