KDTree实现KNN算法
完整的实验代码在我的github上👉QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐
在之前的博客中,我们已经学习了KNN算法的原理和代码实现。KNN算法通过计算待分类样本点和已知样本点之间的距离,选取距离最近的K个点,通过多数表决的方式进行分类。但是,当样本数据量很大时,计算所有样本之间的距离会变得非常耗时,因此我们需要一种更高效的方法来解决这个问题。
KDTree介绍
KDTree是一种常见的数据结构,可以用于高效地查找多维空间中的最近邻点。在KDTree中,每个节点都是一个k维点,节点可以分为左右子树,子树中的节点代表k维空间中的点集。建立KDTree的过程可以通过递归来实现,对于每个节点,我们需要选择一个维度和一个分割值,将该节点的点集按照这个维度的值分为两部分,分别放到左右子树中。分割值可以选取中位数或者其他的分位数,这样可以保证左右子树的平衡,避免树的深度过大,影响查询效率。
基于KDTree的KNN代码实现
在代码实现中,定义了一个 TreeNode
类来表示 KD Tree 的节点,每个节点包含了四个属性:data
表示节点对应的数据点,label
表示数据点的标签,fi
表示当前节点所在的维度,fv
表示当前节点所在维度的特征值,以及 left
和 right
表示左右子节点。
class TreeNode:
def __init__(self,data=None,label=None,fi=None,fv=None,left=None,right=None):
self.data = data
self.label = label
self.fi = fi
self.fv = fv
self.left = left
self.right = right
接着定义了 KDTreeKNN
类,其中 __init__
函数接收一个参数 k
,表示 K 近邻算法中的 K 值,即选择最近的 K 个邻居。buildTree
函数是构建 KD Tree 的核心函数,它接收三个参数:X
表示数据集,y
表示数据集对应的标签,以及 depth
表示当前节点所在的深度。在递归过程中,每次选择当前节点所在的维度 fi
,并将数据集按照该维度的特征值排序,选择排序后中间位置的数据点作为当前节点,然后递归构建左右子树,并返回当前节点。
class KDTreeKNN:
def __init__(self,k=3):
self.k = k
def buildTree(self,X,y,depth):
n_size,n_feature = X.shape
#递归终止条件
if n_size == 1:
tree = TreeNode(data=X[0],label=y[0])
return tree
fi = depth % n_feature
argsort = np.argsort(X[:,fi])
middle_idx = argsort[n_size // 2]
left_idxs,right_idxs = argsort[:n_size//2],argsort[n_size//2+1:]
fv = X[middle_idx,fi]
data,label = X[middle_idx],y[middle_idx]
left,right = None,None
if len(left_idxs) > 0:
left = self.buildTree(X[left_idxs],y[left_idxs],depth+1)
if len(right_idxs) > 0:
right = self.buildTree(X[right_idxs],y[right_idxs],depth+1)
tree = TreeNode(data,label,fi,fv,left,right)
return tree
当我们在KNN算法中找到当前测试