原文来自我的个人博客:http://www.yuanyong.org/blog/python/k-means-using-python
最近在翻译《Programming Computer Vision with Python》第六章Clustering Images,其中用到了k-means,这里根据书中给出的实例对k-means python code做一些解释。关于k-means聚类算法的原理,这里不再赘述,原理可以查阅相关资料。
在给出完整代码之前,我们先来理解两个numpy、scipy两个模块中设计到的两个函数,分别对应的是numpy.vstack()和scipy.cluster.vq()。我们直接看这两个函数的例子:
Example for numpy.vstack()
1 | >>> a = np.array([1, 2, 3]) |
2 | >>> b = np.array([2, 3, 4]) |
输出结果为:
array([[1, 2, 3], [2, 3, 4]])
从这个简单的例子可以看出,np.vstack()这个函数实现connection的作用,即connection(a,b),为了看得更清楚,我们再来看一个这个函数的例子:
1 | >>> a = np.array([[1], [2], [3]]) |
2 | >>> b = np.array([[2], [3], [4]]) |
输出结果这里不给出了,具体可以再python shell上test。好了,现在我们了解了这个函数的作用,我们再来看scipy.cluster.vq()函数的作用,这里也直接给出实例,通过实例解释该函数的作用:
Example for scipy.cluster.vq()
1 | >>> from numpy import array |
2 | >>> from scipy.cluster.vq import vq |
3 | >>> code_book = array([[1.,1.,1.],[2.,2.,2.]]) |
4 | >>> features = array([[ 1.9,2.3,1.7],[ 1.5,2.5,2.2],[ 0.8,0.6,1.7]]) |
5 | >>> vq(features,code_book) |
输出结果为:
(array([1, 1, 0]), array([ 0.43588989, 0.73484692, 0.83066239])),下图解释了该结果的意义,array([1, 1, 0])中的元素表示features中的数据点对应于code_book中离它最近距离的索引,如数据点[1.9, 2.3, 1.7]离code_book中的[2., 2., 2.]最近,该数据点对的对应于code_book中离它最近距离的索引为1,在python中索引值时从0开始的。

当然,对于上面的结果可以用linalg.norm()函数进行验证,验证过程为:
1 | >>> from numpy import array |
2 | >>> from scipy.cluster.vq import vq |
3 | >>> code_book = array([[1.,1.,1.],[2.,2.,2.]]) |
4 | >>> features = array([[ 1.9,2.3,1.7],[ 1.5,2.5,2.2],[ 0.8,0.6,1.7]]) |
5 | >>> vq(features,code_book) |
6 | >>> from numpy import * |
7 | dist = linalg.norm(code_book[1,:] - features[0,:]) |
输出的dist的结果为:dist: 0.43588989435406728
好了,了解完这两个函数,我们可以上完整了演示k-means完整的代码了。
6 | from scipy.cluster.vq import * |
8 | class1 = 1.5 * randn(100,2) |
9 | class2 = randn(100,2) + array([5,5]) |
10 | features = vstack((class1,class2)) |
11 | centroids,variance = kmeans(features,2) |
12 | code,distance = vq(features,centroids) |
14 | ndx = where(code==0)[0] |
15 | plot(features[ndx,0],features[ndx,1],'*') |
16 | ndx = where(code==1)[0] |
17 | plot(features[ndx,0],features[ndx,1],'r.') |
18 | plot(centroids[:,0],centroids[:,1],'go') |
上述代码中先随机生成两类数据,每一类数据是一个100*2的矩阵,centroids是聚类中心,这里聚类中心k=2,并将其作为code_book代用vq(),代码运行结果如下:

上图显示了原数据聚完类后的结果,绿色圆点表示聚类中心。
Hello Python. Enjoy yourself.
Reference:
[1]. http://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.vq.vq.html#scipy.cluster.vq.vq
[2]. http://docs.scipy.org/doc/numpy/reference/generated/numpy.vstack.html