python实现kd树以及最近邻查找算法
一、kd树简介
kd树是一种树形结构,树的每个节点存放一个k维数据,某一节点的子节点可以看作是由过该节点一个平面切割后产生的(想象一下切蛋糕的过程),如此反复产生切割平面,就能为每个数据在空间中建立索引,如下图所示:

由于采用这种特殊的分割方式,使得在利用kd树做最近邻查找时,可以避开一些距离很远的点,查找速度得到了较大的提升,对于空间中N个k维数据,穷举法的算法复杂度为O(Nk),而使用kd树查找的算法复杂度只有O(klog(N))。kd树是一种典型的空间换时间的方式,即花费存储空间为数据建立索引,这样使得后续查找时速度更快,花费时间更少。
二、kd树生成
具体的算法实现主要参考的是这篇文章:https://www.cnblogs.com/eyeszjwang/articles/2429382.html,实现时有少量改动。生成kd树有两个关键的中间过程,即:
1.确定切分域
(1)确定split域:对于所有描述子数据(特征矢量),统计它们在每个维上的数据方差。以SURF特征为例,描述子为64维,可计算64个方差。挑选出最大值,对应的维就是split域的值。数据方差大表明沿该坐标轴方向上的数据分散得比较开,在这个方向上进行数据分割有较好的分辨率;
这段文字用通俗一点的语言来说就是:对于二维的情况,每一次做数据切分的时候,沿着x轴还是y轴做切分是一个问题,那么我们要怎么确定呢?我们可以统计这些二维数据的x值和y值的方差,方差越大说明数据在这一方向上越离散,而数据越离散说明沿着这一方向上数据之间的距离区分度越大,简单点来说就是相互之间隔得更远,我们就用这个方向做切分。
确定了切分域之后,我们就需要来对数据做切分了。
2.确定数据域
(2)确定Node-data域:数据点集Data-set按其第split域的值排序。位于正中间的那个数据点被选为Node-data。此时新的Data-set’ = Data-set\Node-data(除去其中Node-data这一点)。
简单来说,这句话的意思是:现在我们已经确定了沿着x轴做切分,那么我们要怎么决定在x轴哪里做切分呢?我们可以将所有数据根据x值的大小做一个排序,然后选取正中间那个数据的x值作为切分的位置。注意,这里有一个关键的问题是:如果我们有偶数个数据,怎么确定中间那个数据?难道我们选取中间两个数做一下平均???如果没有记错的话这应该是中位数的定义。。。如果这样完全就是自找麻烦!因为我们要确保至少有一个数据的x值落在切分点上,但是取平均之后并不能保证!!!所以更好的办法是,在有两个中间数据的情况下,随便选取一个数据的x值就行了。
决定了在x轴哪里做切分之后,我们就需要把数据做切分了,这里根据数据的x值相对于切分位置的大小,可以归为左节点和右节点,同时不要忘了:当前主节点也要保存一个数据,选取一个x值大小和切分位置相等的数据保存就行(如果有多个随便选一个就行,关键之处在于这个数据的x值落在切割线上。)
3.理解递归树
前面提到过,kd树是一种树形结构,因此可以递归生成,这是树形结构的共性,用程序语言来说,递归就是函数自己调用自己,在理解上也是很自然的。对于一组数据,我们通过找到的一个切分线把数据一分为二,而这个切分线的确定只和这组数据有关,左边的数据归为左节点,右边的数据归为右节点,更进一步,对于左边或者右边的这组数据,我们又可以将其看作一个整体,找到一个切分线把它一分为二,这样将一组数据一分为二的过程反复进行,相当于这个过程函数不断地调用自身,最终生成二叉树,将所有的数据分开。
4.python实现递归树代码
###建立kd树和实现查询功能
import numpy as np
import matplotlib.pyplot as plt
class kdTree:
def __init__(self, parent_node):
'''
节点初始化
'''
self.nodedata = None ###当前节点的数据值,二维数据
self.split = None ###分割平面的方向轴序号,0代表沿着x轴分割,1代表沿着y轴分割
self.range = None ###分割临界值
self.left = None ###左子树节点
self.right = None ###右子树节点
self.parent = parent_node ###父节点
self.leftdata = None ###保留左边节点的所有数据
self.rightdata = None ###保留右边节点的所有数据
self.isinvted = False ###记录当前节点是否被访问过
def print(self):
'''
打印当前节点信息
'''
print(self.nodedata, self.split, self.range)
def getSplitAxis(self, all_data):
'''
根据方差决定分割轴
'''
var_all_data = np.var(all_data, axis=0)
if var_all_data[0] > var_all_data[1]:
return 0
else:
return 1
def getRange(self, split_axis, all_data):
'''
获取对应分割轴上的中位数据值大小
'''
split_all_data = all_data[:, split_axis]
data_count = split_all_data.shape[0]
med_index = int(data_count/2)
sort_split_all_data = np.sort(split_all_data)
range_data = sort_split_all_data[med_index]
return range_data
def getNodeLeftRigthData(self, all_data):
'''
将数据划分到左子树,右子树以及得到当前节点
'''
data_count = all_data.shape[0

最低0.47元/天 解锁文章
771

被折叠的 条评论
为什么被折叠?



