最近看一个车道线识别的算法LaneNet,其中用到了mean shift进行聚类,然后研究了一下这个聚类算法,主要是从代码中了解的,简单记录一下自己的理解,防止以后忘记。meanshift code
使用mean shift聚类我们不用预先知道数据需要聚集为几类,算法会自动找出几个cluster。
随机数据
在开始使用mean shift算法之前先随机生成几蔟数据,方便后面验证聚类效果。
from sklearn.datasets import make_blobs
data, label = make_blobs(n_samples=500, centers=5, cluster_std=1.2, random_state=5)
这样就生成500个数据,有5个类别,使用不同颜色显示出来,可以看到有两组数据很接近,后面可以看到算法的聚类效果。
mean shift聚类
1.首先找出可能是中心点的一些坐标,做法就是把所有的数据通过np.round规整为几十类,然后把这几十类中属于每个类的点的个数大于3的保留下来,这样筛选出来大概28组可能的中心点。其实还可以用其他的方法选择中心点,或者把每个数据都当做中心点也可以。
def get_seeds(self, data):
if self.bin_seeding:
binsize = self.band_width
else:
binsize = 1
seed_list = []