文章目录
KNN 的搜索算法:如何高效找到 K 个最近邻?
关于KNN算法的核心概念可参考这篇文章:文章链接
在 KNN 算法中,“寻找 K 个最近邻” 是核心操作。当数据集规模较小或特征维度较低时,直接计算所有样本距离的暴力搜索法即可满足需求。但在高维空间或大规模数据中,需借助更高效的搜索算法。以下是 4 种主流搜索策略及其适用场景:
一、 暴力搜索(Brute Force)
-
原理:对每个测试样本,逐一计算与所有训练样本的距离,按距离排序后取前 K 小。
-
数学实现:
设训练集大小为 m m m,特征维度为 n n n,单次查询时间复杂度为 O ( m ⋅ n ) O(m \cdot n) O(m⋅n)。
-
示例:
假设有 8 个训练样本(红点),测试样本为蓝点,目标是找到 K=3 个最近邻。算法会计算蓝点到所有 8 个红点的距离,最终选择距离最近的 3 个红点作为邻居。

绘图代码:
import matplotlib.pyplot as plt # 导入matplotlib的pyplot模块,用于数据可视化
import numpy as np # 导入numpy库,用于数值计算和处理数组
from sklearn.neighbors import NearestNeighbors # 从scikit-learn库中导入NearestNeighbors类,用于实现最近邻算法
# 解决中文乱码
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置matplotlib的字体为黑体,以正确显示中文
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题,确保坐标轴负号正常显示
# 定义训练样本(8个二维点)
train_samples = np.array([
[1, 2], [2, 3], [3, 1], [6, 7],
[8, 9], [4, 4], [5, 2], [7, 5]
]) # 创建一个二维numpy数组,包含8个二维数据点,作为训练样本数据
# 定义测试样本
test_sample = np.array([[5, 5]]) # 创建一个二维numpy数组,包含1个二维数据点,作为测试样本数据
# 设置最近邻数量
k = 3 # 设置KNN算法中邻居的数量为3,即寻找距离测试样本最近的3个邻居
# 构建 KNN 模型并查询最近邻
nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(train_samples) # 创建NearestNeighbors对象,
# 设置邻居数量为k,算法自动选择合适的搜索算法(如数据量小用暴力搜索,数据量大时可能选择KD树等),
# 并使用训练样本数据进行模型拟合
distances, indices = nbrs.kneighbors(test_sample) # 使用拟合好的模型,对测试样本进行查询,
# 分别返回测试样本到最近邻居的距离(distances)和最近邻居在训练样本中的索引(indices)
# 绘制训练样本和测试样本
plt.scatter(train_samples[:, 0], train_samples[:, 1], c='red', label='训练样本') # 绘制散点图,
# 使用训练样本的第一列数据作为x轴坐标,第二列数据作为y轴坐标,点的颜色为红色,标签为'训练样本'
plt.scatter(test_sample[0, 0], test_sample[0, 1], c='blue', label='测试样本', s=100, edgecolors='k') # 绘制散点图,
# 使用测试样本的坐标作为点的位置,点的颜色为蓝色,标签为'测试样本',点的大小为100,点的边缘颜色为黑色
# 为每个训练样本添加坐标标注
for i, point in enumerate(train_samples): # 遍历训练样本数组,同时获取每个样本的索引i和数据point
plt.text(point[0] + 0.1, point[1] + 0.1, f"({
point[0]}, {
point[1]})", fontsize=10, color='black') # 在每个训练样本点的旁边添加坐标标注,
# 坐标位置在原坐标基础上稍微偏移(+0.1),标注内容为样本的坐标,字体大小为10,颜色为黑色
# 绘制最近邻连线
for idx in indices[0]: # 遍历测试样本的最近邻居在训练样本中的索引列表(因为只有1个测试样本,所以取indices[0])
plt.plot([test_sample[0, 0], train_samples[idx, 0]], # 绘制线条,起点为测试样本的x坐标,终点为最近邻居样本的x坐标
[test_sample[0, 1], train_samples[idx, 1]], # 绘制线条,起点为测试样本的y坐标,终点为最近邻居样本的y坐标
'k--', lw=1) # 线条颜色为黑色(k),线条样式为虚线(--),线条宽度为1
# 添加图例、标题、轴标签等
plt.legend() # 添加图例,用于标识不同颜色数据点的含义
plt.title('K-最近邻 (K=3)') # 添加图表标题,显示当前使用的KNN算法中K的值为3
plt.xlabel('特征 1') # 添加x轴标签,说明x轴代表的特征含义
plt.ylabel('特征 2') # 添加y轴标签,说明y轴代表的特征含义
plt.grid(True) # 显示网格线,帮助观察数据分布
plt.axis('equal') # 保持纵横坐标轴比例一致,避免图形变形
plt.tight_layout() # 自动调整子图参数,使之填充整个图像区域
plt.show() # 显示绘制好的图形
-
优点:
- 简单直观,无需预处理,适合小规模数据(如 m < 1 0 4 m < 10^4 m<104)。
-
缺点:
- 高维场景下效率极低(维度诅咒),数据量增大时计算量爆炸。
-
sklearn 配置:
algorithm='brute'(默认值之一)。
二、 KD 树(KD-Tree)
-
原理
KNN算法4种主流搜索策略及适用场景

最低0.47元/天 解锁文章
2213

被折叠的 条评论
为什么被折叠?



