# 谱聚类 (spectral clustering)
import matplotlib.pyplot as plt
import numpy as np
from numpy import linalg as LA
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.preprocessing import normalize
import tensorflow as tf
from mpl_toolkits.mplot3d import Axes3D
# Raw records: one list of ';'-separated string fields per input line.
data = []
# Row-index groups: one list per run of consecutive records that share the
# same first field (presumably the country name -- confirm against the file).
cty_data = []


def data_p(fliename):
    """Load a ';'-separated text file into ``data`` and group rows in ``cty_data``.

    Every non-blank line is stripped, split on ';' and appended to the
    module-level list ``data``. Its row index is added to the last group in
    ``cty_data`` when its first field equals the first field of that group's
    first row; otherwise a new group is started.

    Args:
        fliename: Path of the text file to read.

    Returns:
        None. Results are accumulated in the module-level ``data`` and
        ``cty_data`` lists.
    """
    # The context manager guarantees the handle is closed, even on error.
    with open(fliename) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # A blank line would otherwise become a bogus [''] record.
                continue
            fields = stripped.split(';')
            # Index the row by its position in ``data`` so the stored indices
            # stay valid even if the loader is called more than once.
            row_index = len(data)
            data.append(fields)
            if cty_data and fields[0] == data[cty_data[-1][0]][0]:
                cty_data[-1].append(row_index)
            else:
                cty_data.append([row_index])
# Read the raw records; this fills the module-level ``data`` and ``cty_data``.
data_p('C:\\Users\\imac\\Desktop\\2018\\bigdata\\py\\wealth1951.txt')

# Turn both lists into arrays so they can be fancy-indexed later on.
# NOTE(review): the later 2-D indexing of ``cty_data`` assumes every group
# holds the same number of rows -- confirm against the input file.
data = np.asarray(data)
cty_data = np.asarray(cty_data)

# Collects one feature row per group; filled by getdata().
cdata = []
def getdata():
    """Append one row of mean relative changes per group to ``cdata``.

    For every row-index group in ``cty_data`` the fields at positions 3..5 of
    the matching ``data`` rows are read as floats. For each of those three
    columns the relative change between consecutive rows,
    ``(v[k] - v[k-1]) / v[k-1]``, is computed and averaged over the group.
    The resulting three means are appended to the module-level ``cdata``.

    The values are converted to a float array *before* any arithmetic:
    writing floats back into the string array would silently truncate them
    to the array's fixed string width.

    Returns:
        None. ``cdata`` gains one 3-element list of float32 values per group.

    Raises:
        FloatingPointError: If a previous-row value is zero, so the relative
            change is undefined.
    """
    # NOTE(review): rows inside a group are assumed to be in chronological
    # order (the original comments mention years) -- confirm.
    for rows in cty_data:  # one group (country) at a time
        values = np.asarray(data[rows, 3:6], dtype=np.float64)
        # Fail loudly on a zero denominator instead of leaking inf/nan
        # into the clustering features.
        with np.errstate(divide='raise', invalid='raise'):
            rates = np.diff(values, axis=0) / values[:-1]
        # Mean rate of change per column for this group.
        means = rates.mean(axis=0).astype(np.float32)
        cdata.append(list(means))
# Fill ``cdata`` with one feature row per group, then convert it to a float32
# matrix and rescale it with sklearn's ``normalize`` (default settings).
getdata()
cdata = normalize(np.asarray(cdata, dtype=np.float32))
print(cdata)
def similarity_function(points):
    """Return the RBF-kernel affinity matrix of ``points`` without self-loops.

    Args:
        points: Samples accepted by ``rbf_kernel``, one sample per row.

    Returns:
        The square pairwise similarity matrix produced by ``rbf_kernel``,
        with every diagonal entry set to 0.
    """
    affinity = rbf_kernel(points)
    # A point's similarity to itself carries no clustering information.
    np.fill_diagonal(affinity, 0)
    return affinity
def spectral_clustering(points, k):
    """Cluster ``points`` into ``k`` groups with normalized spectral clustering.

    Steps:
      1. Build the affinity matrix W with ``similarity_function``.
      2. Form the symmetric normalized Laplacian
         ``L = I - D^(-1/2) W D^(-1/2)``, where D is the diagonal matrix of
         the row sums of W.
      3. Take the eigenvectors belonging to the ``k`` smallest eigenvalues
         of L, rescale the rows with ``normalize``, and run KMeans on them.

    Args:
        points: Samples, one per row, as accepted by ``similarity_function``.
        k: Number of clusters (and of eigenvectors used for the embedding).

    Returns:
        The label array returned by ``KMeans(n_clusters=k).fit_predict``.

    Raises:
        numpy.linalg.LinAlgError: If a sample has zero total affinity, so
            ``D^(-1/2)`` does not exist.
    """
    W = similarity_function(points)
    degrees = np.sum(W, axis=1)
    if np.any(degrees == 0):
        # Same exception type the explicit matrix inverse used to raise.
        raise LA.LinAlgError("Singular degree matrix: a point has zero affinity")

    # D is diagonal, so D^(-1/2) W D^(-1/2) is W scaled entry-wise by
    # d_i^(-1/2) * d_j^(-1/2); no matrix inverse is required.
    d_inv_sqrt = 1.0 / np.sqrt(degrees)
    L = np.eye(len(points)) - W * np.outer(d_inv_sqrt, d_inv_sqrt)

    # L is symmetric, so use the symmetric solver: it returns real
    # eigenpairs with eigenvalues in ascending order, whereas the general
    # solver can yield complex values from round-off.
    eigvals, eigvecs = LA.eigh(L)
    k_smallest_eigenvectors = normalize(eigvecs[:, :k])
    return KMeans(n_clusters=k).fit_predict(k_smallest_eigenvectors)
# Split the feature rows into 5 clusters.
labels = spectral_clustering(cdata, 5)

# 3-D scatter plot of the three feature columns, coloured by cluster label.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*(cdata[:, axis] for axis in range(3)), c=labels)
plt.show()

print(labels)