K-means 是一种无监督的聚类算法,它将一个未标记的数据集划分成 K 个不同的组(簇)
实现步骤如下
- 1.根据数据集随机选择K个点作为聚类中心 (cluster centroids)
- 2.对于数据集中的每一个数据,找出与各个聚类中心的距离最小值,将其归为那个类
- 3.对每个聚类,计算归属于该聚类的所有数据点的平均值,并将该聚类中心移动到这个平均值处
- 4.重复2和3步骤直至中心点不再变化
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
# import seaborn as sns; sns.set()
%matplotlib inline # %matplotlib显示动态图片(以弹出窗口的方式
一. 通过二维数据集直观显示K-means
================= Part 1: Find Closest Centroids ====================
# Part 1 data: the sample matrix stored under key 'X' in ex7data2.mat.
X = loadmat('ex7data2.mat')['X']
plt.scatter(*X[:, :2].T)  # unclustered view of the first two columns
K = 3  # number of clusters used for this dataset
def kMeansInitCentroids(data, K):
    """Pick K rows of `data` at random (distinct row indices) as initial centroids.

    Rows are sampled without replacement. With replacement (np.random.randint),
    the same row can be picked twice. Two identical centroids tie in the distance
    comparison, one of them gets an empty cluster and a NaN mean, and a NaN
    centroid then wins every argmin on the next assignment. Rows that merely hold
    equal values (e.g. repeated pixel colours) can still coincide.

    Parameters
    ----------
    data : array of shape (m, n)
    K : number of centroids; must not exceed len(data)

    Returns
    -------
    (K, n) array with copies of the chosen rows. `data` is not modified. An
    in-place np.random.shuffle would scramble the pixel order needed later to
    rebuild the compressed image.

    Raises
    ------
    ValueError if K > len(data).
    """
    data = np.asarray(data)
    if K > len(data):
        raise ValueError(f"K={K} exceeds the number of data points ({len(data)})")
    random_index = np.random.choice(len(data), size=K, replace=False)
    return data[random_index]
def findClosestCentroids(X, centroids, K):
    """Assign every sample to its nearest centroid (k-means step 2).

    Parameters
    ----------
    X : array of shape (m, n)
    centroids : array of shape (K, n)
    K : number of centroids; kept for interface compatibility, the count is
        taken from `centroids` itself.

    Returns
    -------
    idx : int array of shape (m, 1). idx[i] is the row of `centroids` closest
          to X[i] in Euclidean distance; ties go to the lowest index.
    """
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # (m, 1, n) - (1, K, n) broadcasts to (m, K, n): every sample against every
    # centroid in one vectorised pass. The old per-row Python loop built a K-row
    # copy of each sample.
    distances = np.linalg.norm(X[:, np.newaxis, :] - centroids[np.newaxis, :, :], axis=2)
    return np.argmin(distances, axis=1).reshape(-1, 1).astype(int)
===================== Part 2: Compute Means =========================
def computeCentroids(X, idx, K, prev_centroids=None):
    """Move each centroid to the mean of the samples assigned to it (k-means step 3).

    Parameters
    ----------
    X : array of shape (m, n)
    idx : integer cluster labels of shape (m,) or (m, 1), as returned by
          findClosestCentroids
    K : number of clusters
    prev_centroids : optional (K, n) array. A cluster that received no samples
        keeps its row from here. Without it that row is NaN, as before, but no
        "Mean of empty slice" warning is emitted.

    Returns
    -------
    (K, n) float array of updated centroids.
    """
    X = np.asarray(X)
    labels = np.asarray(idx).ravel()
    centroids = np.full((K, X.shape[1]), np.nan)
    for k in range(K):
        members = X[labels == k]
        if len(members):
            centroids[k] = members.mean(axis=0)
        elif prev_centroids is not None:
            # Empty cluster: a NaN centroid would make np.argmin pick it for every
            # sample on the next assignment, so keep the previous position.
            centroids[k] = prev_centroids[k]
    return centroids
# Assign samples to the starting centroids quoted in the message below, then
# recompute the centroids so they can be compared with the expected values.
# (idx was previously used here without ever being computed.)
initial_centroids = np.array([[3, 3], [6, 2], [8, 5]])
idx = findClosestCentroids(X, initial_centroids, K)
centroids = computeCentroids(X, idx, K)
print("when initial_centroids is np.array([[3,3], [6,2], [8,5]]), \n the computed centroids \
is \n {},\n (it should be \n [ 2.428301 3.157924 ]\n [ 5.813503 2.633656 ]\n [ 7.119387 3.616684 ]\n)".format(str(centroids)))
=================== Part 3: K-Means Clustering ======================
%matplotlib notebook
# Produces the single final figure; without this line the plots appear one by one
# and the lines are scattered. It is still not a continuous animation: plain
# %matplotlib (pop-up window) is what draws a continuously animated plot.
def plotDataPoints(X, idx):
    """Scatter the samples of X (first two columns), coloured by cluster label.

    Clusters 0, 1 and 2 are drawn yellow, red and blue. Higher labels cycle
    through further matplotlib colour codes instead of raising IndexError.

    Parameters
    ----------
    X : array whose first two columns are plotted
    idx : integer cluster labels of shape (m,) or (m, 1)
    """
    plt.ion()  # interactive mode so successive calls update the figure
    # One colour per cluster, indexed directly by label. The old version built a
    # len(X) x 3 copy of the palette just to pick one entry per row.
    palette = np.array(['y', 'r', 'b', 'g', 'c', 'm', 'k'])
    labels = np.asarray(idx).ravel()
    colors = palette[labels % len(palette)]
    plt.scatter(X[:, 0], X[:, 1], c=colors, s=5, alpha=0.5)
    plt.pause(1)
def plotProgresskMeans(X, centroids, prev, idx, K):
    """Draw one k-means iteration: coloured samples plus centroid movement.

    Each centroid is joined to its previous position by a line and both
    endpoints are marked with an 'x'. K is accepted for interface
    compatibility but not used.
    """
    plotDataPoints(X, idx)
    # Row 0 holds previous positions, row 1 current ones; each column is one centroid.
    path_x = np.vstack((prev[:, 0], centroids[:, 0]))
    path_y = np.vstack((prev[:, 1], centroids[:, 1]))
    plt.plot(path_x, path_y)
    plt.scatter(path_x, path_y, marker="x")
def runkMeans(X, centroids, max_iter, plot=True):
    """Run k-means from the given starting centroids.

    Alternates step 2 (assign each sample to its nearest centroid) and step 3
    (move each centroid to the mean of its samples) for at most `max_iter`
    iterations. It stops early once the centroids no longer move (step 4).

    Parameters
    ----------
    X : array of shape (m, n)
    centroids : (K, n) initial centroids; K is taken from its length
    max_iter : maximum number of iterations
    plot : if True, draw every iteration with plotProgresskMeans (uses the
           first two columns of X)

    Returns
    -------
    centroids : (K, n) final centroids
    idx : (m, 1) labels from the last assignment step (None if max_iter <= 0)
    """
    K = len(centroids)  # previously read from the notebook-global K
    previous_centroids = centroids
    idx = None
    for _ in range(max_iter):
        idx = findClosestCentroids(X, centroids, K)  # step 2
        if plot:
            plotProgresskMeans(X, centroids, previous_centroids, idx, K)
        previous_centroids = centroids
        new_centroids = computeCentroids(X, idx, K)  # step 3
        # An empty cluster comes back as NaN. Keep its previous position so the
        # NaN centroid does not capture every sample on the next assignment.
        empty = np.isnan(new_centroids).any(axis=1)
        new_centroids[empty] = np.asarray(centroids, dtype=float)[empty]
        converged = np.array_equal(new_centroids, centroids)
        centroids = new_centroids
        if converged:  # step 4: centroids stopped moving; further passes change nothing
            break
    return centroids, idx
# Part 3: run up to 10 k-means iterations from fixed starting centroids,
# plotting each iteration (runkMeans plots by default).
initial_centroids = np.array([[3, 3], [6, 2], [8, 5]])
centroids, idx = runkMeans(X, initial_centroids, max_iter=10)
二. 图片压缩
import cv2

birds = cv2.imread("bird_small.png")
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail here with a clear message rather than inside cvtColor.
if birds is None:
    raise FileNotFoundError("could not read image file 'bird_small.png'")
birds_rgb = cv2.cvtColor(birds, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR; convert to RGB
birds_rgb.shape  # (128, 128, 3)
birds_rgb[:1, :3, :]  # first three pixels of the top row
array([[[219, 180, 103],
[230, 185, 116],
[226, 186, 110]]], dtype=uint8)
上述结果说明该图片由 128×128 个像素组成;3 表示每个像素由 RGB 三个通道表示,每个通道有 256 级灰度(数值 0–255),可用 8 个二进制位表示,三个通道共 24 位(称为 24 位色),可表达 $2^{24}=16777216$ 种颜色。
# Number of distinct RGB triples (colours) in the image; 13930 for this picture.
len({tuple(pixel) for pixel in birds_rgb.reshape(-1, 3)})
13930
# Scale to [0, 1]: plt.imshow accepts integers in 0-255 or floats in 0-1, and
# the k-means centroids will be floats.
raw_birds = birds_rgb / 255
raw_birds_A = raw_birds.reshape(-1, 3)  # one row per pixel, one column per channel
K = 16  # size of the compressed colour palette
initial_centroids = kMeansInitCentroids(raw_birds_A, K)
centroids, idx = runkMeans(raw_birds_A, initial_centroids, max_iter=10, plot=False)
# Paint every pixel with its cluster's centroid colour and restore the image's
# own shape (previously hard-coded to 128x128).
compressed_birds = centroids[idx.ravel()].reshape(raw_birds.shape)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(raw_birds)
ax[1].imshow(compressed_birds)
ax[0].set_title('raw_birds')
ax[1].set_title('compressed_birds')
使用sklearn计算
from sklearn.cluster import KMeans

# n_init is set explicitly: its default changed from 10 to 'auto' in
# scikit-learn 1.4 (announced with a FutureWarning in 1.2-1.3).
km = KMeans(n_clusters=16, n_init=10).fit(raw_birds_A)
# Map each pixel to its cluster centre and restore the image's own shape
# (previously hard-coded to 128x128).
compress_A = km.cluster_centers_[km.labels_].reshape(raw_birds.shape)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(raw_birds)
ax[1].imshow(compress_A)
ax[0].set_title('raw_birds')
ax[1].set_title('compressed_birds')