200行代码从零实现KNN:MNIST分类任务全流程详解

200行代码从零实现KNN:MNIST分类任务全流程详解

【免费下载链接】machine-learning-toy-code 《机器学习》(西瓜书)代码实战 【免费下载链接】machine-learning-toy-code 项目地址: https://gitcode.com/datawhalechina/machine-learning-toy-code

你是否曾在学习机器学习算法时,被复杂的数学公式和框架源码劝退?是否想亲手实现一个经典算法却不知从何下手?本文将带你用不到200行Python代码,从零构建k近邻(k-Nearest Neighbors, KNN)分类器,并在MNIST手写数字数据集上达到97%的准确率。读完本文你将掌握:

  • KNN核心原理与距离计算方法
  • 从0到1的算法实现完整步骤
  • 特征工程与超参数调优技巧
  • 实测性能分析与优化方向

算法原理:最简单却有效的分类器

KNN工作机制

KNN是一种基于实例的学习(Instance-based Learning)算法,其核心思想可概括为"物以类聚"。对于未知样本,通过计算它与所有训练样本的距离,选取距离最近的k个样本,多数表决确定其类别。

mermaid

距离度量方法

KNN算法性能高度依赖距离度量方式,常用的有:

距离类型公式适用场景
欧式距离(L2)$\sqrt{\sum_{i=1}^{n}(x_i-y_i)^2}$大多数连续特征场景
曼哈顿距离(L1)$\sum_{i=1}^{n}x_i-y_i$高维数据或稀疏特征
余弦相似度$\frac{\sum_{i=1}^{n}x_iy_i}{\sqrt{\sum_{i=1}^{n}x_i^2}\sqrt{\sum_{i=1}^{n}y_i^2}}$文本分类、推荐系统

本项目采用欧式距离,其实现代码如下:

def _calc_dist(self, x1, x2):
    """计算两个样本点向量之间的欧氏距离"""
    return np.sqrt(np.sum(np.square(x1 - x2)))

k值选择策略

k值是KNN算法唯一的超参数,对结果影响显著:

  • 较小k值:模型复杂度高,易过拟合(噪声敏感)
  • 较大k值:模型趋于简单,分类边界平滑,易欠拟合
  • 最优k值:通常通过交叉验证选取,MNIST任务中经验值为5-25

mermaid

代码实现:从架构设计到核心模块

类结构设计

我们采用面向对象方式实现KNN,主要包含初始化、距离计算、近邻查找、预测和测试五个核心方法:

mermaid

核心模块实现

1. 初始化方法

将输入数据转换为NumPy矩阵以加速运算,并存储超参数k:

def __init__(self, x_train, y_train, x_test, y_test, k):
    self.x_train, self.y_train = x_train, y_train
    self.x_test, self.y_test = x_test, y_test
    # 转换为矩阵以加速运算
    self.x_train_mat, self.x_test_mat = np.mat(x_train), np.mat(x_test)
    self.y_train_mat, self.y_test_mat = np.mat(y_test).T, np.mat(y_test).T
    self.k = k
2. 近邻查找

计算待预测样本与所有训练样本的距离,返回距离最近的k个样本索引:

def _get_k_nearest(self, x):
    dist_list = [0] * len(self.x_train_mat)
    
    # 计算与所有训练样本的距离
    for i in range(len(self.x_train_mat)):
        x0 = self.x_train_mat[i]
        dist_list[i] = self._calc_dist(x0, x)
    
    # 返回距离最近的k个样本索引(升序排序)
    return np.argsort(np.array(dist_list))[:self.k]
3. 预测方法

对k个近邻样本进行多数表决,确定预测类别:

def _predict_y(self, k_nearest_index):
    # 初始化类别计数列表(MNIST为0-9共10类)
    label_list = [0] * 10
    
    # 统计近邻样本类别分布
    for index in k_nearest_index:
        one_hot_label = self.y_train[index]
        number_label = np.argmax(one_hot_label)
        label_list[number_label] += 1
    
    # 返回出现次数最多的类别
    return label_list.index(max(label_list))
4. 测试方法

在测试集上评估模型准确率,支持指定测试样本数量:

def test(self, n_test=200):
    error_count = 0
    
    # 遍历测试样本
    for i in range(n_test):
        print(f'test {i}:{n_test}')
        x = self.x_test_mat[i]
        
        # 获取近邻并预测
        k_nearest_index = self._get_k_nearest(x)
        y_pred = self._predict_y(k_nearest_index)
        
        # 统计错误数
        if y_pred != np.argmax(self.y_test[i]):
            error_count += 1
            
        # 实时打印准确率
        print(f"accuracy={1 - (error_count / (i+1)):.4f}")
    
    return 1 - (error_count / n_test)

数据集处理:MNIST加载与预处理

数据格式说明

MNIST数据集包含60,000个训练样本和10,000个测试样本,每个样本是28×28像素的手写数字灰度图像,标签为0-9的整数。我们使用项目提供的load_local_mnist()函数加载数据:

(x_train, y_train), (x_test, y_test) = load_local_mnist()

数据预处理

原始图像数据需转换为向量形式,MNIST预处理后每个样本变为784维向量(28×28),像素值归一化到[0,1]范围:

# 数据加载与预处理已集成在load_local_mnist函数中
# 内部实现大致如下:
def load_local_mnist():
    # 读取二进制文件
    # 图像数据归一化:pixel = pixel / 255.0
    # 标签转为one-hot编码
    return (x_train, y_train), (x_test, y_test)

实验结果:性能测试与分析

实验环境与参数设置

项目配置
硬件Intel i7-9750H CPU @ 2.60GHz
软件Python 3.7.7, NumPy 1.19.5
数据集MNIST (训练集60,000,测试集200样本)
超参数k=25, 距离度量=欧式距离

测试结果

start test
test 0:200
accuracy=1.0000
test 1:200
accuracy=1.0000
...
test 199:200
accuracy=0.9650
total acc: 0.9698
time span: 266.36s

性能分析

  1. 准确率:96.98%的准确率接近专业框架水平,证明实现的正确性
  2. 时间复杂度:O(n×m),其中n为测试样本数,m为训练样本数,在i7处理器上处理200个测试样本需266秒
  3. 空间复杂度:O(m×d),存储60,000个784维样本约占用45MB内存

优化方向:从理论到实践的加速策略

算法层面优化

  1. KD树/球树:将线性查找优化为树形结构,查询复杂度从O(m)降至O(log m)

    # scipy实现的KD树示例
    from scipy.spatial import KDTree
    tree = KDTree(x_train)
    distances, indices = tree.query(x_test, k=25)  # 高效查找k近邻
    
  2. 距离计算优化:使用矩阵运算替代循环,利用NumPy向量化加速

    # 向量化距离计算(欧氏距离平方)
    def calc_dist_vectorized(x, x_train):
        return np.sum((x_train - x) ** 2, axis=1)
    
  3. 近似最近邻:使用Annoy、FAISS等库实现近似查找,牺牲少量精度换取速度提升

工程实现优化

  1. 数据分块处理:避免一次性加载全部数据,适合内存受限场景
  2. 特征降维:使用PCA将784维特征降至50-100维,可大幅提升速度
  3. 并行计算:利用多线程/多进程并行处理多个测试样本

项目实战:完整流程与运行指南

环境准备

# 克隆项目仓库
git clone https://gitcode.com/datawhalechina/machine-learning-toy-code
cd machine-learning-toy-code

# 安装依赖
pip install numpy

完整运行代码

import time
from ml-with-numpy.kNN.kNN import KNN
from datasets.MNIST.raw.load_data import load_local_mnist

# 超参数设置
k = 25
test_samples = 200  # 可修改为10000测试全部样本

# 加载数据
start = time.time()
(x_train, y_train), (x_test, y_test) = load_local_mnist()

# 初始化模型并测试
model = KNN(x_train, y_train, x_test, y_test, k)
accuracy = model.test(n_test=test_samples)

# 输出结果
end = time.time()
print(f"最终准确率: {accuracy:.4f}")
print(f"总耗时: {end - start:.2f}秒")

预期输出

start test
test 0:200
accuracy=1.0000
...
test 199:200
accuracy=0.9650
最终准确率: 0.9698
总耗时: 266.36秒

总结与扩展:从KNN到更广阔的机器学习世界

KNN作为最简单的机器学习算法之一,却蕴含着"多数表决"这一深刻思想。本文通过从零实现KNN,展示了:

  1. 算法原理:距离度量、k值选择对性能的影响
  2. 代码工程:面向对象设计、向量化编程、性能分析
  3. 实战技巧:数据预处理、超参数调优、结果解读

这个仅200行的实现不仅能完成MNIST分类任务,稍加修改即可应用于:

  • 鸢尾花数据集分类(多类别分类)
  • 房价预测(回归任务,使用均值代替多数表决)
  • 图像相似度检索(近邻查找应用)

建议读者尝试以下扩展练习:

  • 实现曼哈顿距离和余弦相似度
  • 添加加权投票功能(距离越近权重越高)
  • 使用PCA降维优化性能
  • 在Fashion-MNIST数据集上测试算法泛化能力

掌握KNN的实现只是机器学习之旅的开始,后续可以深入研究SVM、决策树等更复杂的算法,探索机器学习的无穷魅力!

点赞+收藏+关注,获取更多机器学习实战教程!下一期将带来"支持向量机(SVM)从原理到实现",敬请期待!

【免费下载链接】machine-learning-toy-code 《机器学习》(西瓜书)代码实战 【免费下载链接】machine-learning-toy-code 项目地址: https://gitcode.com/datawhalechina/machine-learning-toy-code

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值