YoloV5 模型自定义评估-误报、漏报、错报

本文指导如何利用YoloV5模型的val.py文件,通过调整参数获取模型的准确率、误报率和漏报率。通过对比预测与标注,提供了详细的方法来统计各类错误,帮助领导快速理解模型性能。
部署运行你感兴趣的模型镜像

YoloV5模型训练成功后,可以通过自带的val.py文件进行评估分析,其提供mAp、Iou以及混淆矩阵等,很好,但是……领导不认可……/(ㄒoㄒ)/~~。领导要的是最直观的东西,比如这个模型识别目标的准确率,还有误报率等……。那么,领导的要求就是我们开发的方向:

为了得到准确率以及误报、漏报、错报的情况,需要使用模型检测已经标注过的样本,将检测结果与标注进行比对。

YoloV5的val.py文件已经实现了大部分功能,可以直接拿来改造:

首先,原文件中对预测结果的NMS处理函数调用如下:

out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)

val.py本来的一个作用就是与train.py配合,在模型迭代训练过程中提供评估,但是我们这里不需要处理太多可能,只要置信度最高的就好了,修改如下:

out = non_max_suppression(out, conf_thres, iou_thres, multi_label=False, agnostic=single_cls)

然后将检测结果与标注结果做对比,原代码中是在process_batch这个函数中进行处理,计算所有的检测框和标注框的IoU,根据检测类比和标注类别再进行处理。这里做一下改造,代码如下:

def process_batch(batch_i, detections, labels, iouv):
    """
    Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
    Arguments:
        detections (Array[N, 6]), x1, y1, x2, y2, conf, class
        labels (Array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (Array[N, 10]), for 10 IoU levels
    """

    # 正确:0   误报: 1   漏报: 2    错报: 3
    labelNum = labels.shape[0]
    detectNum = detections.shape[0]
    if labelNum == 0 and detectNum == 0:
        return 0
    elif labelNum == 0: #存在误报
        results = torch.zeros(detectNum, 6 , dtype=torch.float, device=iouv.device)
        results[:,0] = batch_i
        results[:,1] = 1
        results[:,2] = detections[:, 5]   #检测类别
        results[:,3] = -1.0                 #标注类别
        #results[:,4] = 0.0                 #检测框和标注框的IoU
        results[:,5] = detections[:, 4]   #检测置信度
        return results
    elif detectNum == 0:  #存在漏报
        results = torch.zeros(labelNum, 6, dtype=torch.float, device=iouv.device)
        results[:,0] = batch_i
        results[:,1] = 2
        results[:,2] = -1.0                #检测类别
        results[:,3] = labels[:, 0]      #标注类别
        #results[:,4] = 0.0                 #检测框和标注框的IoU
        #results[:,5] = 0.0                 #检测置信度
        return results
    else:
        tempresults = torch.zeros(0, 6, dtype=torch.float, device=iouv.device)
        indexes = torch.zeros(1, detectNum, dtype=torch.float, device=iouv.device)
        validNum = 0
        i=0 
        for label in labels:
            j=0
            bMatchLabel = False
            for detect in detections:
                iou = singlebox_iou(label[1:], detect[:4])
                if iou > 0.5:
                    bMatchLabel = True
                    indexes[:,j] = 1.0
                    validNum += 1
                    result = torch.zeros(1, 6 , dtype=torch.float, device=iouv.device)
                    result[:,0] = batch_i
                    result[:,1] = 0 if detect[5] == label[0] else 3
                    result[:,2] = detect[5]     #检测类别
                    result[:,3] = label[0]      #标注类别
                    result[:,4] = iou           #检测框和标注框的IoU
                    result[:,5] = detect[4]     #检测置信度
                    #tempresults.append(result)
                    tempresults = torch.cat((tempresults, result), 0)
                j += 1
            i += 1
            if not bMatchLabel:  #如果有标注狂没有匹配到检测框的为漏报
                validNum += 1
                result = torch.zeros(1, 6 , dtype=torch.float, device=iouv.device)
                result[:,0] = batch_i        
                result[:,1] = 2
                result[:,2] = -1.0          #检测类别
                result[:,3] = label[0]      #标注类别
                #tempresults.append(result)
                tempresults = torch.cat((tempresults, result), 0)

        #如果有检测框没有匹配到标注框的为误报
        noMatchDect = indexes[0][(indexes[0] == 0).nonzero()]
        if noMatchDect.shape[0] > 0:
            for index in noMatchDect:
                result = torch.zeros(1, 6 , dtype=torch.float, device=iouv.device)
                i = int(index[0])
                result[:,0] = batch_i
                result[:,1] = 1
                result[:,2] = detections[i][5]   #检测类别
                result[:,3] = -1.0               #标注类别
                result[:,5] = detections[i][4]   #检测置信度
                #tempresults.append(result)
                tempresults = torch.cat((tempresults, result), 0)

        return tempresults

还得解释一下这里对误报、漏报、错报的定义:

误报:没有标注,但是检测出目标

漏报:有标注,但是没检测出目标

错报:有标注,和检测目标的类别不一样

所有样本处理完后进行统计,统计每个类别的误报、漏报数量,以及错报(其实混淆矩阵就是这个功能,只不过这里用的是数量),代码如下:

def analyse(totalResults, names, save_dir):
    classnames =[]
    for i in range(len(names)):
        classnames.append(str(names[i]))
        
    #误报
    misIndexes = torch.where(totalResults[:,1] == 1.0)
    misItems = totalResults[misIndexes]
    misStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
    line = "误报- "
    for i in range(len(names)):
        n = len(torch.where(misItems[:,2] == i)[0])
        misStats[:,i] = n
        line += names[i] + ":" + str(n) +" "

    print(line + "\n")

    #plt.rcParams['font.sans-serif']=['SimHei']
    #plt.style.use('ggplot')
    #plt.title("误报")
    #plt.xlabel("异物类别")
    #plt.ylabel("数量")
    #plt.bar(classnames, misStats[0].tolist(), width=1.2, color='blue')
    #plt.savefig(str(save_dir) + "/evaluates/误报.jpg")

    #漏报
    failIndexes = torch.where(totalResults[:,1] == 2.0)
    failItems = totalResults[failIndexes]
    failStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
    line = "漏报- "
    for i in range(len(names)):
        n = len(torch.where(failItems[:,3] == i)[0])
        failStats[:,i] = n
        line += names[i] + ":" + str(n) +" "

    #plt.title("漏报")
    #plt.xlabel("异物类别")
    #plt.ylabel("数量")
    #plt.bar(classnames, failStats[0].tolist(), width=1.2, color='blue')
    #plt.savefig(str(save_dir) + "/evaluates/漏报.jpg")

    print(line + "\n")

    #错报
    errorIndexes  = torch.where(totalResults[:,1] == 3.0)
    errorItems = totalResults[errorIndexes]
    errorStats = torch.zeros(len(names), len(names), dtype=torch.int, device=totalResults.device)
    for item in errorItems:
        errorStats[int(item[2]), int(item[3])] += 1

    print("错报- \n")
    print(errorStats)

您可能感兴趣的与本文相关的镜像

Yolo-v5

Yolo-v5

Yolo

YOLO(You Only Look Once)是一种流行的物体检测和图像分割模型,由华盛顿大学的Joseph Redmon 和Ali Farhadi 开发。 YOLO 于2015 年推出,因其高速和高精度而广受欢迎

<think>嗯,用户之前问过关于mbedtls_rsa_pkcs1_decrypt函数报错的问题,现在又遇到了错误码-0x4080。需要先确定这个错误码对应的宏定义是什么。根据之前的回答,用户可能已经了解了一些常见的错误码,比如-0x4100是BAD_INPUT_DATA,-0x4200是INVALID_PADDING等。但-0x4080不在之前提供的错误代码映射表中,所以需要查找mbedtls的文档或源代码来确认。 首先,回忆一下mbedtls的错误码结构。mbedtls使用16进制错误码,高位表示模块,低位表示具体错误。比如RSA模块的错误码通常以0x4开头,所以-0x4080可能属于其他模块。比如,0x4080对应的正数是0xBF80(补码),但可能需要查看具体的定义。 然后,检查mbedtls的错误码头文件,例如error.h。可能发现MBEDTLS_ERR_PK_PASSWORD_REQUIRED对应的值是-0x3A80,而MBEDTLS_ERR_ASN1_OUT_OF_DATA是-0x0060,所以-0x4080可能属于其他模块,比如CIPHER模块或者SSL模块。比如,MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE是-0x6080,但不确定。 或者,可能是用户记错了错误码,比如-0x4280是INVALID_PADDING,而-0x4080可能接近这个值。需要进一步确认。如果用户提供的错误码确实正确,可能需要查阅最新的mbedtls文档或代码库。 另外,用户可能在使用其他函数时遇到了这个错误,而不仅仅是rsa_pkcs1_decrypt。比如,在初始化RSA上下文、解析密钥或进行其他操作时出现的错误。需要用户提供更多上下文,比如代码片段,错误发生的具体位置,以及他们正在执行的操作步骤。 考虑到用户之前的问题是关于RSA解密的,可能这个错误码与RSA模块相关,但之前的映射表没有覆盖到。可能需要检查mbedtls的更新日志或源代码,看看是否有新增的错误码。例如,在某些版本中,可能引入了新的错误码,如MBEDTLS_ERR_RSA_KEY_GEN_FAILED或其他。 如果无法立即找到对应的错误码,建议用户使用mbedtls_strerror函数将错误码转换为字符串描述,这样可以更准确地确定错误原因。例如: char error_buf[256]; mbedtls_strerror(ret, error_buf, sizeof(error_buf)); printf("错误描述: %s\n", error_buf); 此外,检查用户是否正确地初始化了所有结构体,尤其是RSA上下文,是否正确导入了私钥,或者在使用过程中是否修改了某些参数导致不一致。例如,私钥未正确生成或导入,或者在解密时使用了错误的填充模式。 另外,考虑内存分配问题,比如缓冲区溢出或未正确分配内存空间,导致函数执行失败。需要检查输入输出缓冲区的长度是否符合要求,特别是output_len是否足够大以容纳解密后的数据。 如果用户在使用PKCS1解密,可能填充验证失败,但错误码不同。或者,可能在解密过程中遇到了其他问题,如不支持的密钥大小或格式。例如,使用了一个不支持的RSA密钥长度,或者密钥数据损坏。 总结来说,用户需要做的是: 1. 确认错误码的具体宏定义,使用mbedtls_strerror获取描述。 2. 检查RSA上下文的初始化过程,确保私钥正确导入并验证。 3. 验证输入输出参数是否符合函数要求,特别是长度和指针是否有效。 4. 如果可能,提供更多代码片段和操作步骤,以便更精准地定位问题。</think>--- ### 一、错误码解析 **错误码 `-0x4080`** 对应 mbedtls 的宏定义为 **`MBEDTLS_ERR_RSA_KEY_CHECK_FAILED`**,其核心含义为 **RSA 密钥校验失败**。此错误通常发生在以下场景: ```c // 触发场景示例:密钥参数数学关系不成立 if (mbedtls_rsa_check_privkey(&rsa) != 0) { // 此处可能返回 -0x4080 } ``` --- ### 二、关键原因分析 #### 1. **密钥参数不匹配** - **私钥参数错误**:私钥的模数 `N`、私钥指数 `D`、素数因子 `P/Q` 等参数未通过数学关系验证(如 $D \times E \equiv 1 \mod \text{lcm}(P-1, Q-1)$)。 - **常见操作失误**: - 手动拼接密钥参数时遗漏关键字段 - 从 PEM 文件解析密钥时未正确处理 Base64 编码 #### 2. **密钥生成缺陷** - 使用 `mbedtls_rsa_gen_key()` 生成密钥时,若熵源不足或质数生成失败,可能导致密钥不完整: ```c mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &ctr_drbg, 2048, 65537); if (mbedtls_rsa_check_privkey(&rsa) != 0) { // 生成操作后未正确验证密钥 } ``` --- ### 三、调试步骤 #### 1. **密钥完整性校验** ```c // 强制验证密钥参数 int ret = mbedtls_rsa_validate_params( &rsa.P, &rsa.Q, &rsa.D, &rsa.E, &rsa.N, MBEDTLS_RSA_PKCS_V15); if (ret != 0) { printf("密钥参数校验失败: 0x%04X\n", -ret); } ``` #### 2. **密钥数据导出验证** ```c // 导出模数 N 并打印(验证是否符合预期) unsigned char n_buf[256]; size_t n_len; mbedtls_mpi_write_binary(&rsa.N, n_buf, sizeof(n_buf)); printf("Modulus(N) HEX: "); for (n_len = 0; n_len < sizeof(n_buf); n_len++) { printf("%02X", n_buf[n_len]); } ``` --- ### 四、修复方案 #### 1. **密钥重新生成** ```c // 确保使用正确的熵源和密钥位数 mbedtls_entropy_init(&entropy); mbedtls_ctr_drbg_init(&ctr_drbg); mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0); mbedtls_rsa_init(&rsa, MBEDTLS_RSA_PKCS_V15, 0); mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &ctr_drbg, 2048, 65537); ``` #### 2. **密钥完整性修复** ```c // 若密钥为外部导入,需验证参数关系 mbedtls_rsa_complete(&rsa); // 自动计算缺失的DP/DQ/QP等参数 if (mbedtls_rsa_check_privkey(&rsa) != 0) { // 提示用户密钥已损坏 } ``` --- ### 五、错误码对照表 | 十六进制错误码 | 宏定义 | 数学验证失败点 | |----------------|--------------------------|-----------------------------| | 0x4080 | RSA_KEY_CHECK_FAILED | $D \times E \not\equiv 1 \mod \phi(N)$ | | | | $P \times Q \neq N$ | 建议在密钥加载后立即调用 `mbedtls_rsa_check_privkey()` 进行预验证。
评论 15
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值