聚类代码一
def k_means(boxes, cluster_num):
    """Cluster boxes with k-means, using ``1 - wh_iou`` as the distance.

    Centroids are seeded with ``cluster_num`` distinct rows of ``boxes`` drawn
    at random. Each pass assigns every box to its nearest centroid, then
    replaces each centroid with the mean (not the median) of its assigned
    boxes, computed as an accumulated sum divided by the member count. The
    loop stops once the assignment no longer changes.

    Args:
        boxes: NumPy array with one box per row.
            NOTE(review): presumably (width, height) pairs, judging by the
            ``wh_iou`` helper name -- confirm against the caller.
        cluster_num: Number of clusters to produce. Must not exceed the number
            of boxes, because the seeds are drawn without replacement.

    Returns:
        Array holding ``cluster_num`` centroids, one per row. Non-float input
        is promoted to float64 so the means are not truncated.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``cluster_num`` is
            larger than the number of boxes.
    """
    box_number = boxes.shape[0]
    # -1 is never a valid cluster index, so the first pass always performs a
    # centroid update instead of "converging" immediately when every box
    # happens to be assigned to cluster 0.
    last_nearest = np.full(box_number, -1)
    clusters = boxes[np.random.choice(box_number, cluster_num, replace=False)]
    if not np.issubdtype(clusters.dtype, np.floating):
        # Integer centroids would silently truncate every mean written back.
        clusters = clusters.astype(np.float64)

    while True:
        # wh_iou is defined elsewhere; the result is used as one row per box
        # and one column per cluster.
        distances = 1 - wh_iou(boxes, clusters)
        current_nearest = np.argmin(distances, axis=1)
        if (last_nearest == current_nearest).all():
            break

        # Accumulate the boxes of each cluster; np.add.at handles repeated
        # indices correctly, unlike a plain fancy-indexed "+=".
        center_sum = np.zeros(clusters.shape)
        np.add.at(center_sum, current_nearest, boxes)
        counts = np.bincount(current_nearest, minlength=cluster_num)

        # Only clusters that received at least one box are updated; an empty
        # cluster keeps its previous centroid instead of becoming 0/0 = NaN.
        for cluster in np.flatnonzero(counts):
            clusters[cluster] = center_sum[cluster] / counts[cluster]

        last_nearest = current_nearest
    return clusters
聚类代码二
相比代码一更加简洁：直接用 np.mean 对每个簇内的框求均值，省去了先累加求和、再除以数量的步骤。
def k_means(boxes, cluster_num):
    """Cluster boxes with k-means, using ``1 - wh_iou`` as the distance.

    Centroids are seeded with ``cluster_num`` distinct rows of ``boxes`` drawn
    at random. Each pass assigns every box to its nearest centroid and then
    moves each centroid to the mean of its assigned boxes. The loop stops once
    the assignment no longer changes.

    Args:
        boxes: NumPy array with one box per row.
            NOTE(review): presumably (width, height) pairs, judging by the
            ``wh_iou`` helper name -- confirm against the caller.
        cluster_num: Number of clusters to produce. Must not exceed the number
            of boxes, because the seeds are drawn without replacement.

    Returns:
        Array holding ``cluster_num`` centroids, one per row. Non-float input
        is promoted to float64 so the means are not truncated.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``cluster_num`` is
            larger than the number of boxes.
    """
    box_number = boxes.shape[0]
    # -1 is never a valid cluster index, so the first pass always performs a
    # centroid update instead of "converging" immediately when every box
    # happens to be assigned to cluster 0.
    last_nearest = np.full(box_number, -1)
    clusters = boxes[np.random.choice(box_number, cluster_num, replace=False)]
    if not np.issubdtype(clusters.dtype, np.floating):
        # Integer centroids would silently truncate every mean written back.
        clusters = clusters.astype(np.float64)

    while True:
        # wh_iou is defined elsewhere; the result is used as one row per box
        # and one column per cluster.
        distances = 1 - wh_iou(boxes, clusters)
        current_nearest = np.argmin(distances, axis=1)
        if (last_nearest == current_nearest).all():
            break

        for cluster in range(cluster_num):
            members = boxes[current_nearest == cluster]
            # np.mean over an empty selection is NaN; keep the previous
            # centroid when no box was assigned to this cluster.
            if members.shape[0]:
                clusters[cluster] = np.mean(members, axis=0)

        last_nearest = current_nearest
    return clusters