kmeans以及kmeans++聚类生成anchors

本文详细介绍了如何使用Python实现KMeans和KMeans++两种聚类算法,用于YOLO目标检测框架中 anchors 的生成。通过解析VOC格式的XML数据,计算并输出适合模型训练的先验框。同时,提供了数据集格式转换方法,以及评估聚类效果的平均IOU计算。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用yolo系列通常需要通过聚类算法生成anchors,本文给出kmeans以及kmeans++的python实现。
数据格式为VOC的xml文件

若数据集不是voc格式,比如coco格式或者txt格式,可以采用如下方法进行数据集格式转换转为voc再进行聚类计算先验框:
https://blog.youkuaiyun.com/Arcofcosmos/article/details/119983270?spm=1001.2014.3001.5502

kmeans聚类

'''
Author: TuZhou
Version: 1.0
Date: 2021-08-22 18:39:19
LastEditTime: 2021-09-12 17:14:13
LastEditors: TuZhou
Description: 
FilePath: \My_Mobile_Yolo2\utils\kmeans_for_anchors.py
'''
#-------------------------------------------------------------------------------------------------#
#   kmeans虽然会对数据集中的框进行聚类,但是很多数据集由于框的大小相近,聚类出来的9个框相差不大,
#   这样的框反而不利于模型的训练。因为不同的特征层适合不同大小的先验框,越浅的特征层适合越大的先验框
#   原始网络的先验框已经按大中小比例分配好了,不进行聚类也会有非常好的效果。
#-------------------------------------------------------------------------------------------------#
import glob
import xml.etree.ElementTree as ET

import numpy as np

def cas_iou(box,cluster):
    x = np.minimum(cluster[:,0],box[0])
    y = np.minimum(cluster[:,1],box[1])

    intersection = x * y
    area1 = box[0] * box[1]

    area2 = cluster[:,0] * cluster[:,1]
    iou = intersection / (area1 + area2 -intersection)

    return iou

def avg_iou(box,cluster):
    return np.mean([np.max(cas_iou(box[i],cluster)) for i in range(box.shape[0])])

def kmeans(box,k):
    # 取出一共有多少框
    row = box.shape[0]
    
    # 每个框各个点的位置
    distance = np.empty((row,k))
    
    # 最后的聚类位置
    last_clu = np.zeros((row,))

    np.random.seed()

    # 随机选5个当聚类中心
    cluster = box[np.random.choice(row,k,replace = False)]
    # cluster = random.sample(row, k)
    while True:
        # 计算每一行距离五个点的iou情况。
        for i in range(row):
            distance[i] = 1 - cas_iou(box[i],cluster)
        
        # 取出最小点
        near = np.argmin(distance,axis=1)

        if (last_clu == near).all():
            break
        
        # 求每一个类的中位点
        for j in range(k):
            cluster[j] = np.median(
                box[near == j],axis=0)

        last_clu = near

    return cluster

def load_data(path):
    data = []
    # 对于每一个xml都寻找box
    for xml_file in glob.glob('{}/*xml'.format(path)):
        tree = ET.parse(xml_file)
        height = int(tree.findtext('./size/height'))
        width = int(tree.findtext('./size/width'))
        if height<=0 or width<=0:
            continue
        
        # 对于每一个目标都获得它的宽高
        for obj in tree.iter('object'):
            xmin = int(float(obj.findtext('bndbox/xmin'))) / width
            ymin = int(float(obj.findtext('bndbox/ymin')
评论 58
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值