pytorch-semseg源码解读test.py

本文介绍了一个深度学习模型在图像分割任务中的应用,包括模型加载、图像预处理、预测及后处理流程。通过调整命令行参数,可以实现图像标准化和DenseCRF后处理,提高分割精度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这部分代码很坑,原作者代码里若不更改命令行参数norm,则会进行两次标准化

import os
import torch
import argparse
import numpy as np
import scipy.misc as misc


from ptsemseg.models import get_model
from ptsemseg.loader import get_loader
from ptsemseg.utils import convert_state_dict

try:
    import pydensecrf.densecrf as dcrf
except:
    print(
        "Failed to import pydensecrf,\
           CRF post-processing will not work"
    )# 导入CRF后处理


def test(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_file_name = os.path.split(args.model_path)[1]# 命令行传参,模型路径
    model_name = model_file_name[: model_file_name.find("_")]

    # Setup image
    print("Read Input Image from : {}".format(args.img_path))# 图片路径
    img = misc.imread(args.img_path)

    data_loader = get_loader(args.dataset)
    loader = data_loader(root=None, is_transform=True, img_norm=args.img_norm, test_mode=True)
    n_classes = loader.n_classes# 获取指定训练集的类别数

    resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp="bicubic")
    # 将图片变形成模型接受的尺寸
    orig_size = img.shape[:-1]# 除了最后一个元素(通道)的切片,返回H*W
    if model_name in ["pspnet", "icnet", "icnetBN"]:
        # uint8 with RGB mode, resize width and height which are odd numbers
        img = misc.imresize(img, (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1))
    else:
        img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
        # 个别网络的输入是原size+1

    img = img[:, :, ::-1]# 最后一维逆序读取
    img = img.astype(np.float64)
    img -= loader.mean# 标准化,减去均值
    if args.img_norm:
        img = img.astype(float) / 255.0

    # NHWC -> NCHW
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, 0)# 增加一维
    img = torch.from_numpy(img).float()

    # Setup Model
    model_dict = {"arch": model_name}
    model = get_model(model_dict, n_classes, version=args.dataset)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    # 读取了网络结构的名字和对应的参数
    model.load_state_dict(state)
    model.eval()# model.eval() :针对单张图片,不启用 BatchNormalization 和 Dropout
    model.to(device)

    images = img.to(device)
    outputs = model(images)# n张图片*n个class的概率*h*w

    if args.dcrf:
        unary = outputs.data.cpu().numpy()
        unary = np.squeeze(unary, 0)
        unary = -np.log(unary)
        unary = unary.transpose(2, 1, 0)
        w, h, c = unary.shape
        unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1)
        unary = np.ascontiguousarray(unary)

        resized_img = np.ascontiguousarray(resized_img)

        d = dcrf.DenseCRF2D(w, h, loader.n_classes)
        d.setUnaryEnergy(unary)
        d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1)

        q = d.inference(50)
        mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0)
        decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8))
        dcrf_path = args.out_path[:-4] + "_drf.png"
        misc.imsave(dcrf_path, decoded_crf)
        print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
    # 输出h*w,从outputs中取了每个像素预测概率最大的那个值和索引位置,
    # 其中outputs.data.max(1)[]中,返回值有两个,第一个是概率最大值组成的矩阵,
    # 第二个是最大值所在维索引组成的矩阵,这里取得是第二个,即[1]
    # outputs.data.max(1)[1].cpu().numpy()返回1*w*h矩阵,squeeze删除维度为1的维
    if model_name in ["pspnet", "icnet", "icnetBN"]:
        pred = pred.astype(np.float32)
        # float32 with F mode, resize back to orig_size
        pred = misc.imresize(pred, orig_size, "nearest", mode="F")

    decoded = loader.decode_segmap(pred)# 得到Mask颜色图
    print("Classes found: ", np.unique(pred))# 得到寻找到的类
    misc.imsave(args.out_path, decoded)
    print("Segmentation Mask Saved at: {}".format(args.out_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Params")
    parser.add_argument(
        "--model_path",
        nargs="?",
        type=str,
        default="fcn8s_pascal_1_26.pkl",
        help="Path to the saved model",
    )
    parser.add_argument(
        "--dataset",
        nargs="?",
        type=str,
        default="pascal",
        help="Dataset to use ['pascal, camvid, ade20k etc']",
    )

    parser.add_argument(
        "--img_norm",
        dest="img_norm",
        action="store_true",
        help="Enable input image scales normalization [0, 1] \
                              | True by default",
    )
    parser.add_argument(
        "--no-img_norm",
        dest="img_norm",
        action="store_false",
        help="Disable input image scales normalization [0, 1] |\
                              True by default",
    )
    parser.set_defaults(img_norm=True)

    parser.add_argument(
        "--dcrf",
        dest="dcrf",
        action="store_true",
        help="Enable DenseCRF based post-processing | \
                              False by default",
    )
    parser.add_argument(
        "--no-dcrf",
        dest="dcrf",
        action="store_false",
        help="Disable DenseCRF based post-processing | \
                              False by default",
    )
    parser.set_defaults(dcrf=False)

    parser.add_argument(
        "--img_path", nargs="?", type=str, default=None, help="Path of the input image"
    )
    parser.add_argument(
        "--out_path", nargs="?", type=str, default=None, help="Path of the output segmap"
    )
    args = parser.parse_args()
    test(args)

 

### 解决方案 当遇到 `ImportError` 或者类似的模块导入错误时,通常可以通过调整依赖库的版本来解决问题。以下是针对此问题的具体解决方案: #### 方法一:降级 Ultralytics 版本 如果问题是由于 Ultralytics 的高版本引起的,则可以尝试将其版本降低到兼容版本。例如,在引用中提到过的情况表明,版本 8.0.x 是较为稳定的选项之一。 通过以下命令卸载现有版本并安装指定版本: ```bash pip uninstall ultralytics -y pip install ultralytics==8.0.0 ``` 这种方法适用于大多数由版本不匹配引起的问题[^1]。 #### 方法二:手动修改代码逻辑 对于特定情况下无法直接更改环境配置的情形,可以选择修改源码中的冲突部分。例如,如果错误发生在 `utils/general.py` 文件内的第 39 行,涉及 `check_requirements` 函数调用失败,则可以直接移除或替换相关代码片段。 具体操作步骤如下: 1. 定位至项目目录下的 `utils/general.py` 文件; 2. 删除引发异常的相关行(如 `from ultralytics.yolo.utils.checks import check_requirements`)[^4]。 需要注意的是,这种做法可能会影响程序功能的一致性和稳定性,因此仅建议作为临时措施使用。 #### 方法三:更新 Python 和相关工具链 有时,即使解决了上述两个层面的问题仍会存在其他潜在隐患。此时可考虑升级整个开发平台的基础组件——包括但不限于解释器本身及其周边生态插件等要素。比如更换成更高支持度的新版Python解释器(推荐至少为3.9以上),以及同步刷新pip管理工具状态(`python -m pip install --upgrade pip`)。 --- ### 示例代码修正后的形式 假设原始测试脚本路径位于 `D:\graduate\MGN-pytorch\MGN-pytorch\test.py` 中,并且其中包含了有问题的语句结构;那么经过适当改造之后应该像这样呈现出来: ```python try: from ultralytics.hub import checks # 尝试加载官方提供的checks函数定义 except ImportError as e: print(f"Warning: {e}. Falling back to custom implementation.") def checks(): """自定义实现替代缺失的功能""" pass if __name__ == "__main__": result = checks() print(result) ``` 这里采用 try-except 结构捕获可能出现的异常状况,并给出相应的提示信息以便后续排查工作开展得更加顺利[^2]。 --- ### 注意事项 尽管采取了这些补救手段,但如果频繁遭遇类似困境的话还是有必要深入探究根本原因所在。可能是所选用框架尚未完全适配当前操作系统特性所致,也有可能是因为第三方扩展包之间相互干扰造成的结果。总之保持良好的文档记录习惯有助于快速定位故障源头。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值