k-近邻算法(kNN算法)

本文详细介绍了k-近邻算法,包括其原理、常用距离指标(欧几里得和曼哈顿距离)、一般流程(数据准备、选择度量、确定k值和预测),并以scikit-learn中的手写数字数据集为例进行实战演示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

一、k-近邻算法了解

1.概述

2.原理

3.kNN算法中常用的距离指标

(1)欧几里得距离(Euclidean Distance)

(2)曼哈顿距离(Manhattan Distance)

二、kNN算法一般流程

1.数据准备

归一化

2.选择距离度量方法

3.确定k值找到k个最近邻居

4.预测

三、kNN算法实例

1.导入所需的库:

2.加载数据集并进行PCA降维,绘制散点图:

3.创建kNN分类器对象并进行训练:

4.预测并分析准确率:


一、k-近邻算法了解

1.概述

K近邻算法(K-Nearest Neighbors,简称KNN)是一种用于分类和回归的统计方法。KNN 可以说是最简单的分类算法之一,同时,它也是最常用的分类算法之一。

2.原理

存在一个样本数据集合,当预测一个新数据时,根据数据集合中与其相邻最近的k个数据中出现次数最多的分类,作为新数据的分类。通常k是不大于20的整数。

3.kNN算法中常用的距离指标

(1)欧几里得距离(Euclidean Distance)

欧几里得距离,简称欧式距离,是我们在平面几何中最常用的距离计算方法,衡量的是多维空间中两个点之间的绝对距离

二维的公式:

d=\sqrt{\left ( x_{1}-x_{2} \right )^{2}+\left (y_{1}-y_{2} \right )^{2}}

推广到n维的公式:

d=\sqrt{\displaystyle\sum_{i=1}^{n} \left ( x_{1i}-x_{2i} \right )^{2}}

(2)曼哈顿距离(Manhattan Distance)

出租车几何或曼哈顿距离,是种使用在几何度量空间的几何学用语,用以标明两个点在标准坐标系上的绝对轴距总和。

二维平面上的曼哈顿距离:

d=\left | x_{i}-x_{j} \right |-\left | y_{i}-y_{j}\right |

n维上的公式:

d=\displaystyle\sum_{i=1}^{n}\left | x_{1i}-x_{2i} \right |


二、kNN算法一般流程

1.数据准备

包括收集、清洗和预处理数据。

对数据预处理以确保所有特征在计算距离时的权重相等。

归一化

预处理通常采用的是归一化的方法,将取值任意的取值范围转化为0到1或者-1到1之间。

将特征映射到[0,1]之间的公式:

x'=\frac{x-min(x)}{max(x)-min(x)}

将特征映射到[-1,1]之间的公式:

x'=\frac{2*(x-min(x))}{max(x)-min(x)}-1

该方法适合用在数值比较集中的情况。

2.选择距离度量方法

确定用于比较样本之间相似性的度量方法,常见的如欧几里得距离、曼哈顿距离等。

3.确定k值找到k个最近邻居

通过交叉验证等方法来选择最优的K值。

计算该样本与训练集中所有样本的距离。

根据距离对它们进行排序。

选择距离最近的K个样本

4.预测

对于分类任务:查看K个最近邻居中最常见的类别,作为预测结果。例如,如果K=3,并且三个最近邻居的类别是[1, 2, 1],那么预测结果就是类别1。

对于回归任务:预测结果可以是K个最近邻居的平均值或加权平均值。


三、kNN算法实例

本文选用scikit-learn库中的手写数字数据集来进行测试。

1.导入所需的库

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

2.加载数据集并进行PCA降维,绘制散点图

hands = fetch_openml(name='optdigits', version=1)
X = hands.data
y = hands.target
pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(X)

plt.figure(figsize=(10, 8))
for i in range(10):
plt.scatter(X_pca[y==str(i), 0], X_pca[y==str(i), 1], label=str(i))
    
plt.title('Optical Recognition of Handwritten Digits')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.show()

3.创建kNN分类器对象并进行训练

k = 5  # 设置k值
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(X_train, y_train)

4.预测并分析准确率

y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

准确率

以上,是kNN算法基本内容已经完成!


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值