机器学习之利用K-Means算法进行图片压缩
1、K-Means算法工作原理
为了更好的进行解释K-Means算法工作原理,这里引用大佬的两张图片,大佬地址:https://blog.youkuaiyun.com/qq_30374549/article/details/81005872。
可以看到上面的图片总体分3个区域,可以说分三类,更专业的说法是分为三个族。K就是指这个3,我们的任务是找到这3个区域的中心点。怎么找呢?
上面的图是找出3个中心点的过程:我们先随机生成3个点(假如是A,B,C点),先循环图中所有点利用两点间距离公式分别与点A、B、C算出距离,找出距离最短的进行归类(如D是其中一点,AD=3,BD=2,CD=7,则D属于B类),再将三类的点分别求平均得到A1,B1,C1。一直如此循环直到An,Bn,Cn的值较为稳定。
那要怎么实现呢?我用三种K-Means算法实现图片压缩进行解释,就是找出K种颜色去填充图片,使得图片尽可能还原原图。三种算法都是设K=16,当K越大就越接近原图(能用更多颜色去填充图片当然越接近原图,这很好理解)。
我将会用卡卡的一张图片进行示范,图片有点模糊。
2、利用sklearn.cluster进行图片压缩
这种方法根本不需要李姐算法,直接输入K值调用sklearn库就可以了。代码如下:
import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
from sklearn.cluster import KMeans
import numpy as np
from PIL import Image
def compress(K):
path = "./kaka.jpg" # 图片路径,可自行调整
shape = image_to_matrix(path)
pixel = np.asarray(Image.open(path).getdata(), dtype='float')
pixel = pixel.reshape((shape[0] * shape[1], shape[2]))
kmeans = KMeans(n_clusters=K, random_state=0).fit(pixel)
newPixel = []
for i in kmeans.labels_:
newPixel.append(list(kmeans.cluster_centers_[i, :]/255))
newPixel = np.array(newPixel)
newPixel = newPixel.reshape((shape[0], shape[1], shape[2]))
plt.imshow(newPixel)
plt.show()
def image_to_matrix(path):
'''
加载图片
:param path: 图片路径
:return: 返回图片格式
'''
img = Image.open(path)
w, h = img.size
if img.mode == 'RGB': # 图像模式为RGB
shape = (h, w, 3)
elif img.mode == 'L': # 图像模式为L,即是灰度图
shape = (h, w, 1)
return shape
# 入口
if __name__ == '__main__':
K = input("请输入K值:")
if not K: # 如果没有输入K,默认K=32
K = 32
K = int(K)
compress(K)
K=16时,压缩效果如下:
3、利用经典K-Means算法进行图片压缩
经典K-Means算法就是上面第一点介绍的K-Means工作原理。代码如下:
import datetime
import random
import numpy as np
import matplotlib.pyplot as plt # plt 用于显示图片
from PIL import Image
def compress(K=16):
'''
压缩图像
:param K: K簇,压缩后图像使用K个像素点表示全图, K默认为16
:return: None
'''
start = datetime.datetime.now()
path = "./kaka.jpg" # 图片路径,可自行调整
dataSet, shape = image_to_matrix(path) # 加载图片,初始化图像矩阵
centroidList = initCentroids(dataSet, K) # 从数据集中随机选取k个数据
cen_idx, newVar = find_closest_centroids(dataSet, centroidList)
oldVar = 1
time = 1;
# 本次距离总和-上一次的距离总和>=1000,可自行调整,值越大运行时间越短,相应图像精度就越粗糙
while abs(newVar - oldVar) >= 1000:
centroidList = move_centroids(dataSet, cen_idx, K) # 更新簇中心
oldVar = newVar
cen_idx, newVar = find_closest_centroids(dataSet, centroidList) # 从数据集中随机选取k个数据返回
time += 1
print("一共计算了%d次" % time)
newImgData = fill_data(cen_idx, centroidList, shape)
showCluster(newImgData, shape) # 图像矩阵赋‘最恰当’的值
end = datetime.datetime.now() # 绘制图片
print("一共花费时间%s秒" % (str(end-start)))
def showCluster(newImgData, shape):
'''
绘制图片
:param newImgData: 经过K-Means计算出来的图像矩阵
:param shape: 图像格式
:return: None
'''
newPixel = np.array((newImgData / 255).tolist())
newPixel = newPixel.reshape(shape)
plt.imshow(newPixel)
plt.show()
def fill_data(cen_idx, centroidList, shape):
'''
图像矩阵赋‘最恰当’的值
:param cen_idx: 图像每个像素点对应的簇中心下标
:param centroidList: K个中心质点
:param shape: 图像格式
:return: None
'''
m = cen_idx.shape[0]
channel = shape[2]
newImgData = np.zeros((m, channel))
for i in range(m):
newImgData[i] = centroidList[int(cen_idx[i])]
return newImgData
def image_to_matrix(path):
'''
加载图片
:param path: 图片路径
:return: 图像矩阵 (h * w, c)
'''
img = Image.open(path)
w, h = img.size
if img.mode == 'RGB': # 图像模式为RGB
shape = (h, w, 3)
elif img.mode == 'L': # 图像模式为L,即是灰度图
shape = (h, w, 1)
return np.asmatrix(img.getdata(), dtype='float'), shape
def initCentroids(dataSet, K):
'''
从数据集中随机选取k个数据返回
:param dataSet: 图像矩阵(h*w, c)
:param k: K簇
:return: 返回K个随机数据
'''
m, n = dataSet.shape
sh = list(range(m))
random.shuffle(sh)
centroids = []
for i in range(K):
centroids.append(dataSet[sh[i]])
return (np.asarray(centroids))[:, 0, :] # list 转换成Matrix
def find_closest_centroids(img, centroidList):
'''
计算图像每个像素点对应的簇中心下标
:param img:
:param centroidList: (h*w, 1)
:return: 图像每个像素点对应的簇中心下标,以及距离
'''
m, n = img.shape
K, _ = centroidList.shape
sum = 0.0
cen_idx = np.zeros([m, 1], dtype='int')
for i in range(m):
M = np.full([K, n], img[i]) # 生成K行img[i]的矩阵,每行都一样
err = M - centroidList # 矩阵和每个簇中心相减
distance = np.sqrt(np.multiply(err, err).sum(axis=1)) # 每个数值平方再累加
cen_idx[i] = distance.argmin() # 最小列索引
sum += distance[cen_idx[i]]
return cen_idx, sum
def move_centroids(dataSet, cen_idx, K):
'''
更新簇中心
:param dataSet: 图像矩阵(h*w, c)
:param cen_idx: 图像每个像素点对应的簇中心下标
:param K: K簇
:return: 簇中心矩阵 (K, c)
'''
m, n = dataSet.shape
times = np.zeros([K, 1], dtype='int')
centroids = np.zeros([K, n], dtype='float')
for i in range(m):
idx = int(cen_idx[i])
centroids[idx] = centroids[idx] + dataSet[i]
times[idx] = times[idx] + 1
while times.min() == 0: # 避免除零异常
times[times.argmin()] = 1
centroids = centroids / times
return centroids
# 入口
if __name__ == '__main__':
K = input("请输入K值:")
if not K: # 如果没有输入K,默认K=32
K = 32
K = int(K)
compress(K)
K=16时,效果图如下:
4、利用全局K-Means算法进行图片压缩
全局K-Means算法是在经典K-Means的基础上实现的,就是在经典K-Means算法外面再加一层:设K=16,先随机生成一个点A,点A进过经典K-Means算法得到点A1;再随机生成B1,组成(A1,B1)再经过K-Means算法得到点(A2,B2)…直到算出(A16,B16,C16…P16)为止。代码如下:
import datetime
import random
import numpy as np
import matplotlib.pyplot as plt # plt 用于显示图片
from PIL import Image
def compress(K=16):
'''
压缩图像
:param K: K簇,压缩后图像使用K个像素点表示全图, K默认为16
:return: None
'''
start = datetime.datetime.now()
path = "./kaka.jpg" # 图片路径,可自行调整
dataSet, shape = image_to_matrix(path) # 加载图片,初始化图像矩阵
allTime = 0
for i in range(K):
if i == 0: # 当K=1时
centroidList = initCentroids(dataSet) # 从数据集中随机选取1个数据
else: # 当K!=1时
centroidList = np.r_[centroidList, initCentroids(dataSet)] # 簇中心矩阵添加一个随机数据
cen_idx, newVar = find_closest_centroids(dataSet, centroidList)
oldVar = 1
time = 1
allTime += 1
print("K=%d时" % (i+1))
# 本次距离总和-上一次的距离总和>=1000,可自行调整,值越大运行时间越短,相应图像精度就越粗糙
while abs(newVar - oldVar) >= 1000:
centroidList = move_centroids(dataSet, cen_idx, i+1) # 更新簇中心
oldVar = newVar
cen_idx, newVar = find_closest_centroids(dataSet, centroidList) # 从数据集中随机选取k个数据返回
time += 1
allTime += 1
print("计算了%d次" % time)
print("一共计算了%d次" % allTime)
newImgData = fill_data(cen_idx, centroidList, shape)
showCluster(newImgData, shape) # 图像矩阵赋‘最恰当’的值
end = datetime.datetime.now() # 绘制图片
print("一共花费时间%s秒" % (str(end-start)))
def showCluster(newImgData, shape):
'''
绘制图片
:param newImgData: 经过K-Means计算出来的图像矩阵
:param shape: 图像格式
:return: None
'''
newPixel = np.array((newImgData / 255).tolist())
newPixel = newPixel.reshape(shape)
plt.imshow(newPixel)
plt.show()
def fill_data(cen_idx, centroidList, shape):
'''
图像矩阵赋‘最恰当’的值
:param cen_idx: 图像每个像素点对应的簇中心下标
:param centroidList: K个中心质点
:param shape: 图像格式
:return: None
'''
m = cen_idx.shape[0]
channel = shape[2]
newImgData = np.zeros((m, channel))
for i in range(m):
newImgData[i] = centroidList[int(cen_idx[i])]
return newImgData
def image_to_matrix(path):
'''
加载图片
:param path: 图片路径
:return: 图像矩阵 (h * w, c)
'''
img = Image.open(path)
w, h = img.size
if img.mode == 'RGB': # 图像模式为RGB
shape = (h, w, 3)
elif img.mode == 'L': # 图像模式为L,即是灰度图
shape = (h, w, 1)
return np.asmatrix(img.getdata(), dtype='float'), shape
def initCentroids(dataSet):
'''
从数据集中随机选取k个数据返回
:param dataSet: 图像矩阵(h*w, c)
:return: 返回1个随机数据
'''
m, _ = dataSet.shape
x = random.randint(0, m) # 0至m随机取整数
return np.asarray(dataSet[x])
def find_closest_centroids(img, centroidList):
'''
计算图像每个像素点对应的簇中心下标
:param img:
:param centroidList: (h*w, 1)
:return: 图像每个像素点对应的簇中心下标,以及距离
'''
m, n = img.shape
K, _ = centroidList.shape
sum = 0.0
cen_idx = np.zeros([m, 1], dtype='int')
for i in range(m):
M = np.full([K, n], img[i]) # 生成K行img[i]的矩阵,每行都一样
err = M - centroidList # 矩阵和每个簇中心相减
distance = np.sqrt(np.multiply(err, err).sum(axis=1)) # 每个数值平方再累加
cen_idx[i] = distance.argmin() # 最小列索引
sum += distance[cen_idx[i]]
return cen_idx, sum
def move_centroids(dataSet, cen_idx, K):
'''
更新簇中心
:param dataSet: 图像矩阵(h*w, c)
:param cen_idx: 图像每个像素点对应的簇中心下标
:param K: K簇
:return: 簇中心矩阵 (K, c)
'''
m, n = dataSet.shape
times = np.zeros([K, 1], dtype='int')
centroids = np.zeros([K, n], dtype='float')
for i in range(m):
idx = int(cen_idx[i])
centroids[idx] = centroids[idx] + dataSet[i]
times[idx] = times[idx] + 1
while times.min() == 0: # 避免除零异常
times[times.argmin()] = 1
centroids = centroids / times
return centroids
# 入口
if __name__ == '__main__':
K = input("请输入K值:")
if not K: # 如果没有输入K,默认K=32
K = 32
K = int(K)
compress(K)
K=16时,压缩效果如下:
5、总结
K-Means算法并不难理解,我也是参考了很多别人的代码,代码也是有很多瑕疵,运行速度也不够快,直接调用sklearn的不用多说是最快的,毕竟是拿别人的直接用,经典K-Means和全局K-Means是真的慢。后期有机会优化下,刚入坑学习机器学习,希望各位大佬多给意见,一起学习啊。
参考大佬网址:
https://blog.youkuaiyun.com/qq_30374549/article/details/81005872
https://blog.youkuaiyun.com/xholes/article/details/70342133
https://juejin.im/post/5a97d129f265da239235c2eb
https://github.com/SherlockUnknowEn/ImageCompress