机器学习之利用K-Means算法进行图片压缩

1、K-Means算法工作原理

为了更好的进行解释K-Means算法工作原理,这里引用大佬的两张图片,大佬地址:https://blog.youkuaiyun.com/qq_30374549/article/details/81005872在这里插入图片描述
可以看到上面的图片总体分3个区域,可以说分三类,更专业的说法是分为三个族。K就是指这个3,我们的任务是找到这3个区域的中心点。怎么找呢?
在这里插入图片描述
上面的图是找出3个中心点的过程:我们先随机生成3个点(假如是A,B,C点),先循环图中所有点利用两点间距离公式分别与点A、B、C算出距离,找出距离最短的进行归类(如D是其中一点,AD=3,BD=2,CD=7,则D属于B类),再将三类的点分别求平均得到A1,B1,C1。一直如此循环直到An,Bn,Cn的值较为稳定。

那要怎么实现呢?我用三种K-Means算法实现图片压缩进行解释,就是找出K种颜色去填充图片,使得图片尽可能还原原图。三种算法都是设K=16,当K越大就越接近原图(能用更多颜色去填充图片当然越接近原图,这很好理解)。
我将会用卡卡的一张图片进行示范,图片有点模糊。

2、利用sklearn.cluster进行图片压缩

这种方法根本不需要李姐算法,直接输入K值调用sklearn库就可以了。代码如下:

import matplotlib.pyplot as plt  # plt 用于显示图片
import matplotlib.image as mpimg  # mpimg 用于读取图片
from sklearn.cluster import KMeans
import numpy as np
from PIL import Image

def compress(K):
    path = "./kaka.jpg"     # 图片路径,可自行调整
    shape = image_to_matrix(path)
    pixel = np.asarray(Image.open(path).getdata(), dtype='float')
    pixel = pixel.reshape((shape[0] * shape[1], shape[2]))

    kmeans = KMeans(n_clusters=K, random_state=0).fit(pixel)

    newPixel = []
    for i in kmeans.labels_:
        newPixel.append(list(kmeans.cluster_centers_[i, :]/255))
    newPixel = np.array(newPixel)
    newPixel = newPixel.reshape((shape[0], shape[1], shape[2]))

    plt.imshow(newPixel)
    plt.show()

def image_to_matrix(path):
    '''
    加载图片
    :param path: 图片路径
    :return: 返回图片格式
    '''
    img = Image.open(path)
    w, h = img.size
    if img.mode == 'RGB':   # 图像模式为RGB
        shape = (h, w, 3)
    elif img.mode == 'L':   # 图像模式为L,即是灰度图
        shape = (h, w, 1)
    return shape

# 入口
if __name__ == '__main__':
    K = input("请输入K值:")
    if not K:   # 如果没有输入K,默认K=32
        K = 32
    K = int(K)
    compress(K)

K=16时,压缩效果如下:
在这里插入图片描述

3、利用经典K-Means算法进行图片压缩

经典K-Means算法就是上面第一点介绍的K-Means工作原理。代码如下:

import datetime
import random

import numpy as np
import matplotlib.pyplot as plt  # plt 用于显示图片
from PIL import Image


def compress(K=16):
    '''
    压缩图像
    :param K: K簇,压缩后图像使用K个像素点表示全图, K默认为16
    :return: None
    '''
    start = datetime.datetime.now()
    path = "./kaka.jpg"     # 图片路径,可自行调整
    dataSet, shape = image_to_matrix(path)      # 加载图片,初始化图像矩阵
    centroidList = initCentroids(dataSet, K)    # 从数据集中随机选取k个数据
    cen_idx, newVar = find_closest_centroids(dataSet, centroidList)
    oldVar = 1
    time = 1;
    # 本次距离总和-上一次的距离总和>=1000,可自行调整,值越大运行时间越短,相应图像精度就越粗糙
    while abs(newVar - oldVar) >= 1000:
        centroidList = move_centroids(dataSet, cen_idx, K)     # 更新簇中心
        oldVar = newVar
        cen_idx, newVar = find_closest_centroids(dataSet, centroidList)     # 从数据集中随机选取k个数据返回
        time += 1
    print("一共计算了%d次" % time)
    newImgData = fill_data(cen_idx, centroidList, shape)
    showCluster(newImgData, shape)      # 图像矩阵赋‘最恰当’的值
    end = datetime.datetime.now()       # 绘制图片
    print("一共花费时间%s秒" % (str(end-start)))


def showCluster(newImgData, shape):
    '''
    绘制图片
    :param newImgData: 经过K-Means计算出来的图像矩阵
    :param shape: 图像格式
    :return: None
    '''
    newPixel = np.array((newImgData / 255).tolist())
    newPixel = newPixel.reshape(shape)
    plt.imshow(newPixel)
    plt.show()

def fill_data(cen_idx, centroidList, shape):
    '''
    图像矩阵赋‘最恰当’的值
    :param cen_idx: 图像每个像素点对应的簇中心下标
    :param centroidList: K个中心质点
    :param shape: 图像格式
    :return: None
    '''
    m = cen_idx.shape[0]
    channel = shape[2]
    newImgData = np.zeros((m, channel))
    for i in range(m):
        newImgData[i] = centroidList[int(cen_idx[i])]
    return newImgData


def image_to_matrix(path):
    '''
    加载图片
    :param path: 图片路径
    :return: 图像矩阵 (h * w, c)
    '''
    img = Image.open(path)
    w, h = img.size
    if img.mode == 'RGB':   # 图像模式为RGB
        shape = (h, w, 3)
    elif img.mode == 'L':   # 图像模式为L,即是灰度图
        shape = (h, w, 1)
    return np.asmatrix(img.getdata(), dtype='float'), shape


def initCentroids(dataSet, K):
    '''
    从数据集中随机选取k个数据返回
    :param dataSet: 图像矩阵(h*w, c)
    :param k: K簇
    :return: 返回K个随机数据
    '''
    m, n = dataSet.shape
    sh = list(range(m))
    random.shuffle(sh)
    centroids = []
    for i in range(K):
        centroids.append(dataSet[sh[i]])
    return (np.asarray(centroids))[:, 0, :]  # list 转换成Matrix


def find_closest_centroids(img, centroidList):
    '''
    计算图像每个像素点对应的簇中心下标
    :param img:
    :param centroidList: (h*w, 1)
    :return: 图像每个像素点对应的簇中心下标,以及距离
    '''
    m, n = img.shape
    K, _ = centroidList.shape
    sum = 0.0
    cen_idx = np.zeros([m, 1],  dtype='int')
    for i in range(m):
        M = np.full([K, n], img[i])  # 生成K行img[i]的矩阵,每行都一样
        err = M - centroidList    # 矩阵和每个簇中心相减
        distance = np.sqrt(np.multiply(err, err).sum(axis=1))    # 每个数值平方再累加
        cen_idx[i] = distance.argmin()      # 最小列索引
        sum += distance[cen_idx[i]]
    return cen_idx, sum


def move_centroids(dataSet, cen_idx, K):
    '''
    更新簇中心
    :param dataSet: 图像矩阵(h*w, c)
    :param cen_idx: 图像每个像素点对应的簇中心下标
    :param K: K簇
    :return: 簇中心矩阵 (K, c)
    '''
    m, n = dataSet.shape
    times = np.zeros([K, 1], dtype='int')
    centroids = np.zeros([K, n], dtype='float')
    for i in range(m):
        idx = int(cen_idx[i])
        centroids[idx] = centroids[idx] + dataSet[i]
        times[idx] = times[idx] + 1
    while times.min() == 0:  # 避免除零异常
        times[times.argmin()] = 1
    centroids = centroids / times

    return centroids

# 入口
if __name__ == '__main__':
    K = input("请输入K值:")
    if not K:   # 如果没有输入K,默认K=32
        K = 32
    K = int(K)
    compress(K)

K=16时,效果图如下:
在这里插入图片描述

4、利用全局K-Means算法进行图片压缩

全局K-Means算法是在经典K-Means的基础上实现的,就是在经典K-Means算法外面再加一层:设K=16,先随机生成一个点A,点A进过经典K-Means算法得到点A1;再随机生成B1,组成(A1,B1)再经过K-Means算法得到点(A2,B2)…直到算出(A16,B16,C16…P16)为止。代码如下:

import datetime
import random

import numpy as np
import matplotlib.pyplot as plt  # plt 用于显示图片
from PIL import Image


def compress(K=16):
    '''
    压缩图像
    :param K: K簇,压缩后图像使用K个像素点表示全图, K默认为16
    :return: None
    '''
    start = datetime.datetime.now()
    path = "./kaka.jpg"     # 图片路径,可自行调整
    dataSet, shape = image_to_matrix(path)      # 加载图片,初始化图像矩阵
    allTime = 0
    for i in range(K):
        if i == 0:  # 当K=1时
            centroidList = initCentroids(dataSet)   # 从数据集中随机选取1个数据
        else:   # 当K!=1时
            centroidList = np.r_[centroidList, initCentroids(dataSet)]  # 簇中心矩阵添加一个随机数据
        cen_idx, newVar = find_closest_centroids(dataSet, centroidList)
        oldVar = 1
        time = 1
        allTime += 1
        print("K=%d时" % (i+1))
        # 本次距离总和-上一次的距离总和>=1000,可自行调整,值越大运行时间越短,相应图像精度就越粗糙
        while abs(newVar - oldVar) >= 1000:
            centroidList = move_centroids(dataSet, cen_idx, i+1)     # 更新簇中心
            oldVar = newVar
            cen_idx, newVar = find_closest_centroids(dataSet, centroidList)     # 从数据集中随机选取k个数据返回
            time += 1
            allTime += 1
        print("计算了%d次" % time)
    print("一共计算了%d次" % allTime)
    newImgData = fill_data(cen_idx, centroidList, shape)
    showCluster(newImgData, shape)      # 图像矩阵赋‘最恰当’的值
    end = datetime.datetime.now()       # 绘制图片
    print("一共花费时间%s秒" % (str(end-start)))


def showCluster(newImgData, shape):
    '''
    绘制图片
    :param newImgData: 经过K-Means计算出来的图像矩阵
    :param shape: 图像格式
    :return: None
    '''
    newPixel = np.array((newImgData / 255).tolist())
    newPixel = newPixel.reshape(shape)
    plt.imshow(newPixel)
    plt.show()

def fill_data(cen_idx, centroidList, shape):
    '''
    图像矩阵赋‘最恰当’的值
    :param cen_idx: 图像每个像素点对应的簇中心下标
    :param centroidList: K个中心质点
    :param shape: 图像格式
    :return: None
    '''
    m = cen_idx.shape[0]
    channel = shape[2]
    newImgData = np.zeros((m, channel))
    for i in range(m):
        newImgData[i] = centroidList[int(cen_idx[i])]
    return newImgData


def image_to_matrix(path):
    '''
    加载图片
    :param path: 图片路径
    :return: 图像矩阵 (h * w, c)
    '''
    img = Image.open(path)
    w, h = img.size
    if img.mode == 'RGB':   # 图像模式为RGB
        shape = (h, w, 3)
    elif img.mode == 'L':   # 图像模式为L,即是灰度图
        shape = (h, w, 1)
    return np.asmatrix(img.getdata(), dtype='float'), shape


def initCentroids(dataSet):
    '''
    从数据集中随机选取k个数据返回
    :param dataSet: 图像矩阵(h*w, c)
    :return: 返回1个随机数据
    '''
    m, _ = dataSet.shape
    x = random.randint(0, m)    # 0至m随机取整数
    return np.asarray(dataSet[x])


def find_closest_centroids(img, centroidList):
    '''
    计算图像每个像素点对应的簇中心下标
    :param img:
    :param centroidList: (h*w, 1)
    :return: 图像每个像素点对应的簇中心下标,以及距离
    '''
    m, n = img.shape
    K, _ = centroidList.shape
    sum = 0.0
    cen_idx = np.zeros([m, 1],  dtype='int')
    for i in range(m):
        M = np.full([K, n], img[i])  # 生成K行img[i]的矩阵,每行都一样
        err = M - centroidList    # 矩阵和每个簇中心相减
        distance = np.sqrt(np.multiply(err, err).sum(axis=1))    # 每个数值平方再累加
        cen_idx[i] = distance.argmin()      # 最小列索引
        sum += distance[cen_idx[i]]
    return cen_idx, sum


def move_centroids(dataSet, cen_idx, K):
    '''
    更新簇中心
    :param dataSet: 图像矩阵(h*w, c)
    :param cen_idx: 图像每个像素点对应的簇中心下标
    :param K: K簇
    :return: 簇中心矩阵 (K, c)
    '''
    m, n = dataSet.shape
    times = np.zeros([K, 1], dtype='int')
    centroids = np.zeros([K, n], dtype='float')
    for i in range(m):
        idx = int(cen_idx[i])
        centroids[idx] = centroids[idx] + dataSet[i]
        times[idx] = times[idx] + 1
    while times.min() == 0:  # 避免除零异常
        times[times.argmin()] = 1
    centroids = centroids / times

    return centroids

# 入口
if __name__ == '__main__':
    K = input("请输入K值:")
    if not K:   # 如果没有输入K,默认K=32
        K = 32
    K = int(K)
    compress(K)

K=16时,压缩效果如下:
在这里插入图片描述

5、总结

K-Means算法并不难理解,我也是参考了很多别人的代码,代码也是有很多瑕疵,运行速度也不够快,直接调用sklearn的不用多说是最快的,毕竟是拿别人的直接用,经典K-Means和全局K-Means是真的慢。后期有机会优化下,刚入坑学习机器学习,希望各位大佬多给意见,一起学习啊。

参考大佬网址:
https://blog.youkuaiyun.com/qq_30374549/article/details/81005872
https://blog.youkuaiyun.com/xholes/article/details/70342133
https://juejin.im/post/5a97d129f265da239235c2eb
https://github.com/SherlockUnknowEn/ImageCompress

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值