一、工作原理
- 对于每个实例,该算法都会计算在它一小段距离内 ε \varepsilon ε 内有多少个实例。该区域称为实例的 ε − \varepsilon- ε− 邻域。
- 如果一个实例在其 ε \varepsilon ε 邻域内至少包含 min_samples 个实例(包含自身),则该实例为核心实例。
- 核心实例附近的所有实例都属于同一集群。这个邻域可能包括其他核心实例。因此,一长串相邻的核心实例形成一个集群。
- 任何不是核心实例且邻居中没有核心实例的实例都被视为异常
二、参数
sklearn中参数详解:详解
两个重要参数:
- eps: ε \varepsilon ε 的大小
- min_samples : 核心实例中至少包含的实例个数
三、变量
sklearn.dataset中的make_moons()函数:链接
make_circles()函数与make_moons()函数相似
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_moons,make_circles
X,y = make_moons(n_samples = 1000,noise = 0.05)
make_moons()生成数据为:
make_circles()生成数据为:
DBSCAN对象的变量:
- labels_ : 每个实例的集群标签。异常实例的集群标签为-1
- core_sample_indices_ : 包含每个核心实例的索引
- components_ : 可以得到核心实例本身
X,y = make_moons(n_samples = 100 ,noise = 0.1)
dbscan = DBSCAN(eps = 0.2,min_samples =2).fit(X)
labels_ = dbscan.labels_
print("标签为:{}".format(labels_))
len_ = len(dbscan.core_sample_indices_)
print("核心实例个数为:{}".format(len_))
data = dbscan.components_
print("核心实例:{}".format(data))
结果为:
四、代码
1.难点
- matplotlib折线图:详解
1. 导包
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_moons,make_circles
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
2. 数据集
X,y = make_moons(n_samples = 1000 ,noise = 0.1)
X2,y2 = make_circles(n_samples = 1000,noise = 0.025)
3. 函数
#参数:eps_为邻域大小,min_sample_为核心实例邻域中最小实例数目,X_为数据集;本次用了两个数据集
def DBscan(eps_,min_sample_,X_):
#创建模型并训练
dbscan = DBSCAN(eps = eps_,min_samples =min_sample_).fit(X_)
#绘制散点图时的c参数,大小和点的个数一样,即dbscan的变量core_sample_indices的len:len(dbscan.core_sample_indices_)
mask = np.arange(len(dbscan.core_sample_indices_))
#mask内为每核心实例的集群索引,异常实例的集群索引为-1;
#在dbscan的core_sample_indices和components变量中并没有出现非核心实例,非核心实例和异常实例直接被算法过滤掉了。
for idx,i in enumerate(dbscan.core_sample_indices_):
mask[idx] = dbscan.labels_[i]
#画散点图
plt.scatter(dbscan.components_[:,0],dbscan.components_[:,1],c = mask)
#标题
plt.title("eps = {},min_samples = {}".format(eps_,min_sample_))
4. 调用函数
对每个数据集分别调试两组参数,第二组参数效果较好,也就是第三列
plt.figure(figsize = (12,8))
plt.subplot(231)
plt.scatter(X[:,0],X[:,1],c = y)
plt.title("Original data")
plt.subplot(232)
DBscan(0.05,5,X)
plt.subplot(233)
DBscan(0.1,5,X)
plt.subplot(234)
plt.scatter(X2[:,0],X2[:,1],c = y2)
plt.title("Original data2")
plt.subplot(235)
DBscan(0.05,7,X2)
plt.subplot(236)
DBscan(0.07,7,X2)
plt.show()
结果为: