机器学习之knn算法----手写数字识别mnist

本文介绍了KNN算法在手写数字识别MNIST数据集的应用,详细探讨了KNeighborsClassifier参数对模型准确率的影响,包括n_neighbors、weights、algorithm等。通过实验展示了不同算法和K值的选择对预测效果的差异,并强调了计算效率和权重对模型性能的重要性。文章以实际代码运行结果为依据,展示了调整参数后的准确率变化,鼓励读者进行更多尝试和学习。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

机器学习之knn算法----手写数字识别mnist

knn介绍

KNeighborsClassifier(n_neighbors=5, weights=‘uniform’, algorithm=‘auto’, leaf_size=30, p=2, metric=‘minkowski’, metric_params=None, n_jobs=1, **kwargs)
Parameters

n_neighbors : int, (default = 5)—k的值

weights : str or callable, optional (default = ‘uniform’)

- 'uniform' : 默认投票是权重一致
- 'distance' : 距离越近权重越高

algorithm : {‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}, optional

- 'ball_tree' 球树,创建球树,减少计算次数
- 'kd_tree' `KDTree`,创建二叉树,减少计算次数
- 'brute' 暴力法
- 'auto' 自动选择。当然数量比较少时,还是使用暴力法

leaf_size :

-int, optional (default = 30)
-当使用 BallTree or KDTree.  这个是跟节点的数量,默认是30,当数据量很大时,可以加深节点

p : integer, optional (default = 2)

- p=1曼哈顿距离。p=2欧氏距离
- 应该都知道,没什么介绍的。一般都用欧氏距离
- 

metric : string or callable, default ‘minkowski’

- 闵可夫斯基,这就是一个公式,不同的p值对应不同的距离。p=2就是欧式

n_jobs : int, optional (default = 1)

-  让几个人去计算,-1就是你电脑有几个核就用几个。默认就是一个

进入正题

这些参数到底怎么影响knn的正确率呢?我们来看看。。。。

这里使用的是.mat的二进制文件。网上有下载,我后面也会上传

导入模块
import scipy.io as spio

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

因为是mat文件,所以用scipy来获取:

mnist = spio.loadmat("./data/mnist.mat")
display(mnist)

输出:

{'__header__': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Sun Mar 30 03:19:02 2014',
 '__version__': '1.0',
 '__globals__': [],
 'mldata_descr_ordering': array([[array(['label'], dtype='<U5'), array(['data'], dtype='<U4')]],
       dtype=object),
 'data': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 'label': array([[0., 0., 0., ..., 9., 9., 9.]])}

这是一种MATLAB的格式文件,有兴趣的可以自行百度,在这不做重点
可以看出数据都是在data里,目标值在label中。
获取数据:

#因为直接获取到的数据shape为(784,70000),所以要进行转置
#每一个图片的格式都是28*28
data = mnist["data"].T
taget = mnist["label"].T
taget = taget.reshape(70000,)

数据我处理成了一份图片格式的保存在本地,一共七万张图片。后面我也会上传。
代码就不上了,比较简单。需要的也可以私信我。

数据切分和加载算法:
X_train,x_test,Y_train,y_test = train_test_split(data,taget,test_size=0.2)
#使用默认参数
knn = KNeighborsClassifier()
knn.fit(X_train,Y_train)

计算得分

knn.score(x_test,y_test)

输出:0.9715714285714285

这是默认参数的准确率。那么他到底预测了哪些是错的哪些的对的呢?我怎么能看到?
先接收一下他预测出来的数据

y_ = knn.predict(x_test)

好,我们来画个图:

plt.figure(figsize=
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值