#代码:https://www.cnblogs.com/wsine/p/5180769.html
#证明:https://zhuanlan.zhihu.com/p/149597282
import numpy as np
def loadDataSet(fileName):
    """Read a comma-separated numeric text file into a list of rows.

    Each non-blank line becomes one list of floats; blank lines (for example
    a trailing empty line at the end of the file) are skipped instead of
    crashing on ``float('')``.

    Args:
        fileName: path of the text file to read.

    Returns:
        list[list[float]]: one inner list per data line, in file order.

    Raises:
        ValueError: if a non-blank field cannot be parsed as a float.
    """
    dataSet = []
    with open(fileName) as fr:
        for line in fr:  # iterate the file directly; no need for readlines()
            # strip() with no argument removes ALL surrounding whitespace,
            # including the trailing newline -- not only spaces.
            curline = line.strip()
            if not curline:
                continue
            dataSet.append([float(field) for field in curline.split(',')])
    return dataSet
def kMeans(dataSet, k):
n = np.shape(dataSet)[1] #数据集的列数
centroids = np.mat(np.zeros((k, n)))#用于存储所有质心
for j in range(n):#随机生成质心
minJ = min(dataSet[:, j])#第j列的最小值
maxJ = max(dataSet[:, j])#第j列的最大值
rangeJ = float(maxJ - minJ)
centroids[:, j] = minJ + rangeJ * np.random.rand(k, 1)#在第j列的变化域中随机取一个点
m = np.shape(dataSet)[0] #数据集的行数(个数)
clusterAssment = np.mat(np.zeros((m, 2)))# m*2的零矩阵 用来记录每个点最近的中心和距离?
clusterChanged = True
while clusterChanged:#直到聚类结果不变时再停止
clusterChanged = False
for i in range(m): #寻找每个元素最近的质心 #遍历每个元素
minDist = 9999999.0#极大值
minIndex = -1
for j in range(k):#遍历所有质心
distJI = np.sqrt(np.sum(np.power(centroids[j, :]- dataSet[i, :], 2)))#计算点到质心的距离#欧式距离
if distJI < minDist:
minDist = distJI
minIndex = j
if clusterAssment[i, 0] != minIndex:
clusterChanged = True
clusterAssment[i, :] = minIndex, minDist**2
for cent in range(k): # 更新质心的位置
ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]]
centroids[cent, :] = np.mean(ptsInClust, axis=0)
return centroids, clusterAssment # 返回:类中心,
def plotFeature(dataSet, centroids, clusterAssment):
    """Scatter-plot a clustering result using the first two feature columns.

    Each cluster is drawn with its own marker shape and colour (cycling
    through the style lists when there are more clusters than styles), and
    the cluster centres are overlaid as large red '+' markers.  Finishes by
    calling ``plt.show()``.

    Args:
        dataSet: m x n data (n >= 2); only columns 0 and 1 are plotted.
            ``np.matrix``, ndarray or nested list.
        centroids: k x n cluster centres, e.g. as returned by ``kMeans``.
        clusterAssment: m x 2 assignment table whose column 0 holds each
            row's cluster index, e.g. as returned by ``kMeans``.
    """
    import matplotlib.pyplot as plt  # imported here: only plotting needs matplotlib

    # np.asarray accepts matrices as well as arrays/lists, so no `.A` needed.
    data = np.asarray(dataSet)
    cents = np.asarray(centroids)
    labels = np.asarray(clusterAssment)[:, 0]

    scatterMarkers = ['s', 'o', '^', '8', 'p', 'd', 'v', 'h', '>', '<']
    scatterColors = ['blue', 'green', 'yellow', 'purple', 'orange', 'black', 'brown']
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for i in range(cents.shape[0]):  # draw one cluster at a time
        markerStyle = scatterMarkers[i % len(scatterMarkers)]  # marker shape
        colorStyle = scatterColors[i % len(scatterColors)]  # marker colour
        ptsInCurCluster = data[labels == i]
        ax.scatter(ptsInCurCluster[:, 0], ptsInCurCluster[:, 1],
                   marker=markerStyle, c=colorStyle, s=90)
    # Mark the cluster centres with red plus signs.
    ax.scatter(cents[:, 0], cents[:, 1], marker='+', c='red', s=300)
    plt.show()
if __name__ == '__main__':
    # Script entry point: load the points, cluster them, and plot the result.
    dataSet = loadDataSet('788points.txt')  # read the data file
    dataSet = np.asmatrix(dataSet)  # np.mat was removed in NumPy 2.0
    resultCentroids, clustAssing = kMeans(dataSet, 6)  # cluster into 6 groups
    plotFeature(dataSet, resultCentroids, clustAssing)  # draw the clusters
