import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def LoadData(filename):
    """Read a CSV file and return (up to) its first four columns as floats.

    pandas consumes the header row, so only the cell values are returned,
    never the column labels.

    Parameters
    ----------
    filename : str or path-like
        Path of the CSV file to read.

    Returns
    -------
    numpy.ndarray of float64
        2-D array with one row per CSV data row and the first four columns.

    Raises
    ------
    ValueError
        If one of those four columns contains text that cannot be parsed as
        a number.
    """
    data = pd.read_csv(filename)
    # Slice first, then cast. ``DataFrame.values`` on the whole frame becomes
    # an object array as soon as any column (e.g. a text label column) is
    # non-numeric, and that object dtype would leak into the numeric k-means
    # code. Casting to float fails loudly instead.
    return data.iloc[:, 0:4].to_numpy(dtype=float)
# Initialize k random centroids
def randCent(dataSet, k):
    """Draw ``k`` random centroids inside the bounding box of ``dataSet``.

    Coordinate ``j`` of every centroid is sampled uniformly from
    ``[min(dataSet[:, j]), max(dataSet[:, j]))``.

    Parameters
    ----------
    dataSet : 2-D array-like, shape (m, n)
        Samples, one per row.
    k : int
        Number of centroids (clusters) to generate.

    Returns
    -------
    numpy.matrix, shape (k, n)
        One centroid per row.
    """
    n = np.shape(dataSet)[1]  # number of columns (features)
    # np.mat was removed in NumPy 2.0; np.asmatrix keeps the same return type.
    centroids = np.asmatrix(np.zeros((k, n)))
    # Fill one column at a time. Drawing rand(k, 1) per column keeps the
    # random stream in the same order as before, so seeded runs reproduce.
    for j in range(n):
        column = np.asarray(dataSet[:, j], dtype=float)
        minJ = column.min()
        rangeJ = float(column.max() - minJ)
        centroids[:, j] = minJ + rangeJ * np.random.rand(k, 1)
    return centroids
# Compute the Euclidean distance between two vectors
def distEclud(vecA, vecB):
    """Return the Euclidean (L2) distance between two vectors.

    The element-wise differences are squared and summed into a single
    scalar, whose square root is returned.
    """
    delta = vecA - vecB
    squared = np.power(delta, 2)
    return np.sqrt(np.sum(squared))
def plot_centroid(data, clusterAssignment):
    """Scatter-plot points coloured by cluster and show the figure.

    Parameters
    ----------
    data : 2-D array-like, shape (m, >=2)
        Column 0 is drawn on the x axis and column 1 on the y axis.
    clusterAssignment : 2-D array-like, shape (m, >=1)
        Column 0 holds each point's cluster index.

    Cluster 0 is drawn as red circles and cluster 1 as green stars. Every
    other value is drawn as blue plus signs.
    """
    points = np.asarray(data)
    labels = np.asarray(clusterAssignment)[:, 0]
    # One scatter call per style group instead of one per point; the old
    # per-point loop created a separate artist for every sample.
    groups = (
        (labels == 0, "red", "o"),
        (labels == 1, "green", "*"),
        ((labels != 0) & (labels != 1), "blue", "+"),
    )
    for mask, colour, marker in groups:
        plt.scatter(points[mask, 0], points[mask, 1], c=colour, marker=marker)
    plt.xlabel('petal length')
    plt.ylabel('petal width')
    plt.show()
def kmeans(data, k, init_centroids=None):
    """Cluster the rows of ``data`` into ``k`` groups with Lloyd's k-means.

    Each pass assigns every sample to its nearest centroid, then moves each
    centroid to the mean of its members. Iteration stops after a pass in
    which no sample changes cluster.

    Parameters
    ----------
    data : 2-D array-like, shape (m, n)
        Samples to cluster, one per row.
    k : int
        Number of clusters; must be at least 1.
    init_centroids : 2-D array-like, shape (k, n), optional
        Starting centroids. When omitted, ``randCent(data, k)`` draws them.
        The caller's array is never modified.

    Returns
    -------
    centroids : numpy.matrix, shape (k, n)
        Final centroid of every cluster.
    clusterAssignment : numpy.ndarray, shape (m, 2)
        Column 0 is each sample's cluster index (stored as float). Column 1
        is its Euclidean distance to the centroid it was assigned to in the
        final pass.

    Raises
    ------
    ValueError
        If ``k < 1`` or ``init_centroids`` does not have shape (k, n).
    """
    if k < 1:
        raise ValueError("k must be at least 1, got {}".format(k))
    data = np.asarray(data, dtype=float)
    if init_centroids is None:
        init_centroids = randCent(data, k)
    # Plain float ndarray copy: works for np.matrix input and never mutates
    # the caller's array.
    centroids = np.array(init_centroids, dtype=float)
    if centroids.shape != (k, data.shape[1]):
        raise ValueError(
            "init_centroids must have shape {}, got {}".format(
                (k, data.shape[1]), centroids.shape))
    m = data.shape[0]
    # Column 0: cluster index, column 1: distance to that cluster's centroid.
    clusterAssignment = np.zeros((m, 2))
    clusterChanged = True
    # Stop once a full pass leaves every sample's cluster unchanged.
    while clusterChanged:
        # Assignment step: distance from every sample to every centroid, (m, k).
        diffs = data[:, np.newaxis, :] - centroids[np.newaxis, :, :]
        dists = np.sqrt(np.sum(diffs ** 2, axis=2))
        # argmin returns the first minimum, so ties go to the lowest index.
        labels = np.argmin(dists, axis=1)
        clusterChanged = bool(np.any(labels != clusterAssignment[:, 0]))
        clusterAssignment[:, 0] = labels
        # Refresh distances every pass, not only when a label changes, so
        # column 1 never holds stale (or initial zero) values.
        clusterAssignment[:, 1] = dists[np.arange(m), labels]
        print ("-----------------")
        # Update step: move each centroid to the mean of its members.
        for cent in range(k):
            members = data[labels == cent]
            # An empty cluster keeps its previous centroid. Taking the mean
            # of no rows would give NaN and lose the cluster permanently.
            if len(members):
                centroids[cent] = members.mean(axis=0)
    return np.asmatrix(centroids), clusterAssignment
if __name__ == "__main__":
    # Cluster the first four CSV columns into three groups, print the
    # centroids, then plot the result using columns 2 and 3.
    features = LoadData('iris.csv')
    final_centroids, assignment = kmeans(features, 3)
    print (final_centroids)
    plotted_columns = features[:, 2:4]
    plot_centroid(plotted_columns, assignment)
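# Tag: clustering (聚类)
# Web-page residue: "最新推荐文章于 2024-10-25 14:06:52 发布" (latest recommended article published 2024-10-25 14:06:52)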