【机器学习】案例1.5——KMeans聚类,实现图像压缩

AI助手已提取文章相关产品:

一、项目背景与解决方案

1. 项目背景

数字图像的RGB色彩空间中,每个像素由红、绿、蓝三个通道(各0-255取值)组成,理论上包含1600多万种颜色。但实际图像中大量像素的颜色存在高度冗余,直接存储所有像素的完整颜色信息会占用过多存储空间。因此,图像颜色量化(矢量量化) 成为图像压缩的关键技术之一:通过减少图像中颜色的数量,在视觉效果损失可控的前提下降低存储成本。

2. 解决的核心问题
  • 可视化图像颜色的三维分布特征(RGB三个维度),直观呈现颜色冗余性;
  • 基于KMeans聚类算法对像素颜色进行聚类,用少量聚类中心(量化颜色)替代原始颜色,实现图像压缩;
  • 对比原始图像与量化后的图像,验证颜色量化的效果。
3. 技术方案
  • 读取图像并将像素值归一化到0-1区间,转换为二维数组(像素数×3);
  • histogramdd和3D散点图可视化颜色的三维频数分布,用折线图展示颜色频数的排序特征;
  • 采用KMeans聚类对像素颜色聚类(随机采样部分像素训练聚类模型,提升效率);
  • 用聚类中心替代原像素颜色,还原量化后的图像并与原图对比展示。

二、带详细注释的代码

#!/usr/bin/python
# -*- coding: utf-8 -*-
# 声明Python解释器编码格式为UTF-8,避免中文注释乱码

# 导入核心库
from PIL import Image  # Python Imaging Library,用于图像读取、处理
import numpy as np     # 数值计算库,处理图像数组、矩阵运算
from sklearn.cluster import KMeans  # 导入KMeans聚类算法
import matplotlib  # 绘图基础库,配置中文显示
import matplotlib.pyplot as plt  # 绘图核心库,绘制图像、图表
from mpl_toolkits.mplot3d import Axes3D  # 3D绘图扩展库,绘制RGB三维散点图


def restore_image(cb, cluster, shape):
    """
    根据聚类中心和聚类结果,还原量化后的图像
    :param cb: 聚类中心数组(量化后的颜色集合),形状为(聚类数, 3)
    :param cluster: 每个像素的聚类标签数组,形状为(像素总数,)
    :param shape: 原始图像的形状,格式为(行数, 列数, 3)
    :return: 量化后的图像数组,形状与原始图像一致
    """
    # 解析原始图像的行数、列数(第三个维度为通道数,用dummy占位)
    row, col, dummy = shape
    # 创建空数组存储量化后的图像,维度与原图一致
    image = np.empty((row, col, 3))
    index = 0  # 像素索引,遍历所有像素
    # 逐行、逐列遍历像素
    for r in range(row):
        for c in range(col):
            # 用当前像素聚类标签对应的聚类中心,替换原像素颜色
            image[r, c] = cb[cluster[index]]
            index += 1  # 索引自增,处理下一个像素
    return image


def show_scatter(a):
    """
    绘制图像颜色的三维频数散点图和频数排序折线图
    :param a: 图像像素的二维数组,形状为(像素总数, 3)(每行对应一个像素的RGB值)
    """
    N = 10  # 三维直方图的分箱数,控制散点图的粒度
    print('原始像素颜色数据(前几行):\n', a[:5])  # 打印少量数据示例

    # 计算RGB三维空间的频数直方图
    # density: 每个分箱的频数;edges: 每个维度的分箱边界
    density, edges = np.histogramdd(a, bins=[N, N, N], range=[(0, 1), (0, 1), (0, 1)])
    density /= density.max()  # 频数归一化(0-1),方便散点大小映射

    # 生成三维网格坐标,用于3D散点图的位置
    x = y = z = np.arange(N)
    d = np.meshgrid(x, y, z)

    # 创建第一个画布,绘制3D散点图
    fig = plt.figure(1, facecolor='w')  # facecolor='w'设置背景为白色
    ax = fig.add_subplot(111, projection='3d')  # 创建3D子图
    # 绘制3D散点图:x/y/z为网格坐标,c='r'红色,s为散点大小(映射频数),depthshade=True开启深度阴影
    ax.scatter(d[1], d[0], d[2], c='r', s=100*density, marker='o', depthshade=True)
    # 设置坐标轴标签(中文)
    ax.set_xlabel(u'红色分量')
    ax.set_ylabel(u'绿色分量')
    ax.set_zlabel(u'蓝色分量')
    plt.title(u'图像颜色三维频数分布', fontsize=20)  # 设置标题

    # 创建第二个画布,绘制频数排序折线图
    plt.figure(2, facecolor='w')
    den = density[density > 0]  # 过滤掉频数为0的分箱
    den = np.sort(den)[::-1]  # 降序排序频数
    t = np.arange(len(den))  # 生成横坐标(频数排名)
    # 绘制折线图:红色折线+绿色圆点,线宽lw=2
    plt.plot(t, den, 'r-', t, den, 'go', lw=2)
    plt.title(u'图像颜色频数分布(降序)', fontsize=18)
    plt.grid(True)  # 显示网格

    plt.show()  # 展示所有图表


if __name__ == '__main__':
    # ===================== 全局配置:解决matplotlib中文显示问题 =====================
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']  # 设置默认字体为黑体
    matplotlib.rcParams['axes.unicode_minus'] = False  # 解决负号显示为方块的问题

    # ===================== 核心参数设置 =====================
    num_vq = 20  # 矢量量化的颜色数量(KMeans聚类数),可调整(如50/100)

    # ===================== 图像读取与预处理 =====================
    im = Image.open('../data/Lena.png')  # 读取图像(路径需根据实际调整)
    # 将图像转换为numpy数组,类型转为float,归一化到0-1区间(方便聚类计算)
    image = np.array(im).astype(np.float) / 255
    # 将图像从三维数组(行, 列, 3)重塑为二维数组(像素总数, 3),每行对应一个像素的RGB值
    image_v = image.reshape((-1, 3))

    # ===================== 初始化KMeans模型 & 可视化颜色分布 =====================
    model = KMeans(num_vq)  # 初始化KMeans模型,指定聚类数
    show_scatter(image_v)  # 调用函数绘制颜色分布图表

    # ===================== KMeans聚类训练 & 预测 =====================
    N = image_v.shape[0]  # 获取图像总像素数
    # 随机采样1000个像素(避免全量像素训练耗时),提升聚类效率
    idx = np.random.randint(0, N, size=1000)
    image_sample = image_v[idx]  # 采样后的像素数据
    model.fit(image_sample)  # 用采样数据训练KMeans模型,得到聚类中心
    c = model.predict(image_v)  # 对所有像素预测聚类标签(每个像素属于哪个聚类)

    # 打印聚类结果信息(调试用)
    print('所有像素的聚类标签(前10个):\n', c[:10])
    print('聚类中心(量化后的颜色):\n', model.cluster_centers_)

    # ===================== 绘制原始图像 & 量化后图像 =====================
    plt.figure(figsize=(15, 8), facecolor='w')  # 创建画布,设置尺寸15×8

    # 子图1:原始图像
    plt.subplot(121)  # 1行2列,第1个子图
    plt.axis('off')  # 关闭坐标轴显示
    plt.title(u'原始图片', fontsize=18)  # 设置标题
    plt.imshow(image)  # 显示原始图像

    # 子图2:量化后的图像
    plt.subplot(122)  # 1行2列,第2个子图
    # 调用restore_image函数,还原量化后的图像
    vq_image = restore_image(model.cluster_centers_, c, image.shape)
    plt.axis('off')  # 关闭坐标轴显示
    plt.title(u'矢量量化后图片:%d色' % num_vq, fontsize=20)  # 标题标注量化颜色数
    plt.imshow(vq_image)  # 显示量化后的图像

    plt.tight_layout(1.2)  # 调整子图间距,避免重叠
    plt.show()  # 展示对比图像

三、关键说明

  1. 聚类采样优化:代码中随机采样1000个像素训练KMeans,而非全量像素,既保证聚类效果,又大幅降低计算耗时;
  2. 参数可调性num_vq(量化颜色数)可根据需求调整(如设为16/32/50),数值越小压缩率越高,但视觉损失越大;
  3. 可视化价值:颜色的3D频数分布可直观展示图像的主色调分布,频数折线图可看出大部分颜色的频数极低(冗余性),验证了颜色量化的合理性。

您可能感兴趣的与本文相关内容

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值