import numpy as np
import pickle
# Load the pickled clusters.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only load clusters.pkl from a trusted source.
with open('clusters.pkl', 'rb') as f:
    clusters = pickle.load(f)

# data_list[k] is the concatenation of the `.data` items of every element of
# clusters[k] (presumably the data points of each sub-cluster — confirm
# against whatever produced clusters.pkl).
data_list = [
    [d for data_cluster in datas for d in data_cluster.data]
    for datas in clusters
]
def DBSCAN(points, eps, min_points):
    """
    Cluster ``points`` with the DBSCAN algorithm.

    :param points: array-like of data points, one point per row; a 1-D
        input is treated as a column of scalar points
    :param eps: neighbourhood radius; two points are neighbours when their
        Euclidean distance is <= eps (every point is its own neighbour)
    :param min_points: minimum neighbourhood size (the point itself
        included) for a point to be a core point
    :return: list with one int per point: the cluster id (0, 1, ...) or
        -1 for noise
    """
    # Accept lists and 1-D arrays; the vectorised distance below needs a
    # 2-D float ndarray.
    points = np.asarray(points, dtype=float)
    if points.ndim == 1:
        points = points.reshape(-1, 1)
    n = len(points)
    labels = [-1] * n

    def region(idx):
        # Indices (ascending) of all points within eps of points[idx].
        return np.flatnonzero(np.linalg.norm(points - points[idx], axis=1) <= eps)

    core_points = np.zeros(n, dtype=bool)
    for i in range(n):
        core_points[i] = region(i).size >= min_points

    cluster_id = 0
    for seed in range(n):
        # Only an unlabelled core point starts a new cluster; border points
        # are picked up while some cluster is being expanded.
        if labels[seed] != -1 or not core_points[seed]:
            continue
        labels[seed] = cluster_id
        # Explicit stack instead of recursion: a recursive expansion raises
        # RecursionError once a cluster has ~1000 chained core points.
        stack = [seed]
        while stack:
            current = stack.pop()
            for neighbor in region(current):
                if labels[neighbor] == -1:
                    labels[neighbor] = cluster_id
                    if core_points[neighbor]:
                        stack.append(neighbor)
        cluster_id += 1
    return labels
def expand_cluster(points, labels, core_points, point_id, cluster_id, eps, min_points):
    """
    Grow cluster ``cluster_id`` outward from ``point_id`` (in place).

    Every still-unlabelled (-1) point reachable from ``point_id`` through a
    chain of core points within ``eps`` gets ``labels[...] = cluster_id``.
    A non-core ``point_id`` is simply labelled and not expanded.

    :param points: array-like of data points, one point per row
    :param labels: mutable sequence of cluster labels (-1 = unassigned),
        modified in place
    :param core_points: boolean per-point core flags
    :param point_id: index of the point to expand from
    :param cluster_id: label to assign to reached points
    :param eps: neighbourhood radius (Euclidean, inclusive)
    :param min_points: unused here (core status is precomputed in
        ``core_points``); kept for signature compatibility
    :return: None
    """
    if not core_points[point_id]:
        labels[point_id] = cluster_id
        return
    points = np.asarray(points)
    # Iterative depth-first expansion with an explicit stack; the previous
    # recursive form used one Python frame per chained core point and hit
    # the interpreter's recursion limit on large clusters.
    stack = [point_id]
    while stack:
        current = stack.pop()
        neighbor_ids = np.flatnonzero(np.linalg.norm(points - points[current], axis=1) <= eps)
        for i in neighbor_ids:
            if labels[i] == -1:
                labels[i] = cluster_id
                if core_points[i]:
                    stack.append(i)
# Cluster the points of the first group and plot each cluster in its own colour.
data_value = np.array(data_list[0])
l = DBSCAN(data_value, 0.5, 5)

import matplotlib.pyplot as plt

colors = ['red', 'blue', 'green', 'orange', 'purple', 'black']
labels_arr = np.asarray(l)
# Noise points (label -1) are intentionally not drawn, as before.
for cluster in np.unique(labels_arr[labels_arr != -1]):
    members = data_value[labels_arr == cluster]
    # Cycle through the palette: indexing colors[cluster] directly raised
    # IndexError as soon as DBSCAN found more than len(colors) clusters.
    # One scatter call per cluster instead of one per point.
    plt.scatter(members[:, 0], members[:, 1], c=colors[cluster % len(colors)])
plt.show()
# Result figure (结果图)
