一、 KNN 算法简介
KNN 是一种简单、直观的监督学习算法,其核心思想是 “物以类聚”。它通过找出训练集中与测试样本最相似的 K 个邻居,并根据它们的标签进行预测。既可以用于分类任务,也可以用于回归任务。
二、KNN 算法详解
在明确 KNN 算法的核心思想和应用流程后,我们通过数学公式和具体案例进一步拆解其实现逻辑。
2.1 核心决策公式:
分类任务(多数投票法)
假设测试样本的 K 个邻居标签集合为 { y 1 , y 2 , … , y K } \{y_1, y_2, \dots, y_K\} {y1,y2,…,yK},预测类别 y ^ \hat{y} y^ 为出现频率最高的类别:
y ^ = arg max c ∈ Y ∑ i = 1 K I ( y i = c ) \hat{y} = \arg\max_{c \in \mathcal{Y}} \sum_{i=1}^{K} \mathbb{I}(y_i = c) y^=argmaxc∈Y∑i=1KI(yi=c)
其中 I ( ⋅ ) \mathbb{I}(\cdot) I(⋅) 为指示函数,条件成立时返回 1,否则返回 0。
回归任务(均值预测法)
假设 K 个邻居的目标值为 { t 1 , t 2 , … , t K } \{t_1, t_2, \dots, t_K\} {t1,t2,…,tK},预测值 t ^ \hat{t} t^ 为目标值的算术平均:
t ^ = 1 K ∑ i = 1 K t i \hat{t} = \frac{1}{K} \sum_{i=1}^{K} t_i t^=K1∑i=1Kti
2.2 分类任务数学示例:
问题设定:二维特征空间中,训练集包含 5 个样本:
-
红色类别: ( 1 , 1 ) (1,1) (1,1), ( 2 , 2 ) (2,2) (2,2), ( 3 , 1 ) (3,1) (3,1)
-
蓝色类别: ( 4 , 5 ) (4,5) (4,5), ( 5 , 4 ) (5,4) (5,4)
测试样本为 ( 2 , 3 ) (2,3) (2,3),取 K = 3 K=3 K=3,使用欧式距离。
计算步骤:
- 计算距离(欧式距离公式 d = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 d = \sqrt{(x_1-x_2)^2 + (y_1-y_2)^2} d=(x1−x2)2+(y1−y2)2):
-
到 ( 1 , 1 ) (1,1) (1,1): ( 2 − 1 ) 2 + ( 3 − 1 ) 2 = 5 ≈ 2.24 \sqrt{(2-1)^2 + (3-1)^2} = \sqrt{5} \approx 2.24 (2−1)2+(3−1)2=5≈2.24
-
到 ( 2 , 2 ) (2,2) (2,2): ( 2 − 2 ) 2 + ( 3 − 2 ) 2 = 1 \sqrt{(2-2)^2 + (3-2)^2} = 1 (2−2)2+(3−2)2=1
-
到 ( 3 , 1 ) (3,1) (3,1): ( 2 − 3 ) 2 + ( 3 − 1 ) 2 = 5 ≈ 2.24 \sqrt{(2-3)^2 + (3-1)^2} = \sqrt{5} \approx 2.24 (2−3)2+(3−1)2=5≈2.24
-
到 ( 4 , 5 ) (4,5) (4,5): ( 2 − 4 ) 2 + ( 3 − 5 ) 2 = 8 ≈ 2.83 \sqrt{(2-4)^2 + (3-5)^2} = \sqrt{8} \approx 2.83 (2−4)2+(3−5)2=8≈2.83
-
到 ( 5 , 4 ) (5,4) (5,4): ( 2 − 5 ) 2 + ( 3 − 4 ) 2 = 10 ≈ 3.16 \sqrt{(2-5)^2 + (3-4)^2} = \sqrt{10} \approx 3.16 (2−5)2+(3−4)2=10≈3.16
-
排序并选取 K=3 个最近邻居:
距离排序:距离排序: ( 2 , 2 ) (2,2) (2,2)(1)< ( 1 , 1 ) (1,1) (1,1)(2.24)= ( 3 , 1 ) (3,1) (3,1)(2.24)< ( 4 , 5 ) (4,5) (4,5)(2.83)< ( 5 , 4 ) (5,4) (5,4)(3.16)
前 3 个邻居为:前 3 个邻居为: ( 2 , 2 ) (2,2) (2,2)(红)、 ( 1 , 1 ) (1,1) (1,1)(红)、 ( 3 , 1 ) (3,1) (3,1)(红)
-
多数投票:3 个邻居均为红色,预测结果为红色。
2.3 分类任务数学示例:
问题设定:一维特征空间,训练集样本为:
-
( 1 , 2 ) (1, 2) (1,2), ( 3 , 4 ) (3, 4) (3,4), ( 5 , 6 ) (5, 6) (5,6), ( 7 , 8 ) (7, 8) (7,8)
测试样本为 x = 4 x=4 x=4,取 K = 2 K=2 K=2,使用曼哈顿距离(绝对值差)。
计算步骤:
- 计算距离(曼哈顿距离公式 d = ∣ x 1 − x 2 ∣ d = |x_1 - x_2| d=∣x1−x2∣):
-
到 ( 1 , 2 ) (1,2) (1,2): ∣ 4 − 1 ∣ = 3 |4-1|=3 ∣4−1∣=3
-
到 ( 3 , 4 ) (3,4) (3,4): ∣ 4 − 3 ∣ = 1 |4-3|=1 ∣4−3∣=1
-
到 ( 5 , 6 ) (5,6) (5,6): ∣ 4 − 5 ∣ = 1 |4-5|=1 ∣4−5∣=1
-
到 ( 7 , 8 ) (7,8) (7,8): ∣ 4 − 7 ∣ = 3 |4-7|=3 ∣4−7∣=3
-
排序并选取 K=2 个最近邻居:
距离排序:距离排序: ( 3 , 4 ) (3,4) (3,4)(1)= ( 5 , 6 ) (5,6) (5,6)(1)< ( 1 , 2 ) (1,2) (1,2)(3)= ( 7 , 8 ) (7,8) (7,8)(3)
前 2 个邻居为:前 2 个邻居为: ( 3 , 4 ) (3,4) (3,4)、 ( 5 , 6 ) (5,6) (5,6)
-
均值预测:目标值平均为 ( 4 + 6 ) / 2 = 5 (4+6)/2 = 5 (4+6)/2=5,预测结果为 5。
KNN常见的搜索算法:网页连接
三、 KNN 算法的 API
- 3.1 KNN 分类 API:
from sklearn.neighbors import KNeighborsClassifier
# 创建KNN分类器对象,n_neighbors为K值,默认为5
knn_classifier = KNeighborsClassifier(n_neighbors=5)
- 3.2KNN 回归 API:
from sklearn.neighbors import KNeighborsRegressor
# 创建KNN回归器对象,n_neighbors为K值,默认为5
knn_regressor = KNeighborsRegressor(n_neighbors=5)
四、 KNN 算法的应用方式
-
准备数据:给定训练样本
X_train
(特征)、y_train
(标签),以及待预测样本X_test
。 -
计算距离:对每个
X_test
中的样本,计算它与所有X_train
样本之间的距离(如欧氏距离、曼哈顿距离等)。 -
排序距离:将训练集中样本按与测试样本的距离从小到大排序。
-
选取最近的 K 个邻居:取出前 k 个距离最近的训练样本。
-
投票决策(分类)或取平均(回归):
-
分类任务:统计这 k 个样本中出现次数最多的类别,作为预测结果。
-
回归任务:取这 k 个样本目标值的平均值或加权平均值作为预测值。
以下是一个 KNN 分类的示例代码:
# 导入所需的库和模块
from sklearn.datasets import load_iris # 从sklearn导入鸢尾花数据集加载工具
from sklearn.model_selection import train_test_split # 导入数据集划分工具
from sklearn.neighbors import KNeighborsClassifier # 导入K近邻分类器
from sklearn.metrics import accuracy_score # 导入准确率计算工具
# 加载鸢尾花数据集
iris = load_iris() # 加载内置的鸢尾花数据集
X = iris.data # 获取特征数据(花瓣/萼片的长度宽度)
y = iris.target # 获取目标标签(鸢尾花的类别:0,1,2)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.2, # 测试集占比20%
random_state=44 # 随机种子确保结果可复现
)
# 创建K近邻分类器实例
knn = KNeighborsClassifier(
n_neighbors=5 # 设置K值(最近邻数量)为5
)
# 训练模型
knn.fit(X_train, y_train) # 使用训练集特征和标签训练分类器
# 进行预测
y_pred = knn.predict(X_test) # 使用训练好的模型对测试集进行预测
# 评估模型性能
accuracy = accuracy_score(y_test, y_pred) # 计算预测结果与真实标签的准确率
print("Accuracy: ", accuracy) # 输出模型准确率(范围0-1,值越高性能越好)
# 补充说明:
# 1. `load_iris()` 数据集包含150个样本,4个特征,3个类别
# 2. `train_test_split` 的random_state参数确保每次划分结果相同
# 3. KNN原理:根据测试样本在特征空间中最近的K个训练样本的类别投票决定分类
# 4. 典型应用场景:简单分类任务、基准模型测试
五、 总结
KNN 算法虽然简单,但在很多实际场景中都能发挥出不错的效果。通过合理选择距离计算方法、进行特征预处理以及利用交叉验证和网格搜索进行参数调优,可以进一步提高 KNN 模型的性能。在实际应用中,我们需要根据数据的特点和任务的需求,灵活运用这些方法和技巧,以达到最佳的效果。希望通过本文的介绍,大家对 KNN 算法有了更深入的理解和认识。