代码参考GitHub:lihang-code的第三章。
kdtree用于计算距离多维数据目标点距离最近的点,用于采集K近邻中的k个距离最近的点。
李航书中用一个二维的例子来解释这一过程,由于是二维,距离选择欧式距离(二范数),猜想高维数据将用到更高的范数来计算所谓距离。
题目:
给定一个二维数据集:
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
计算距离点(3,4.5)距离最近的点。
这个例子的优点在于,它的最近点不是分割出这一空间的叶子节点,也不是这个叶子结点的父亲,而是父亲节点的兄弟节点,遍历过程较为复杂,正好可以用于分析程序。
构造KD树
构造部分还是挺好理解的,
- 节点
节点数据、以节点为分割轴的轴方向、左右节点。 - 构造函数
维度循环(0~k),子节点为父亲节点分割维数的加一取余。
数据取当前数据的中位数,当然,需要先排序。
class KdNode(object):
def __init__(self, dom_elt, split, left, right):
self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)
self.split = split # 整数(进行分割维度的序号)
self.left = left # 该结点分割超平面左子空间构成的kd-tree
self.right = right # 该结点分割超平面右子空间构成的kd-tree
class KdTree(object):
def __init__(self, data):
k = len(data[0]) # 数据维度
def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNode
if not data_set: # 数据集为空
return None
# key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
# operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
#data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
data_set.sort(key=lambda x: x[split])
split_pos = len(data_set) // 2 # //为Python中的整数除法
median = data_set[split_pos] # 中位数分割点
split_next = (split + 1) % k # cycle coordinates
# 递归的创建kd树
return KdNode(
median,
split,
CreateNode(split_next, data_set[:split_pos]), # 创建左子树
CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树
self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点
# KDTree的前序遍历
def preorder(root):
print(root.dom_elt)
if root.left: # 节点不为空
preorder(root.left)
if root.right:
preorder(root.right)
搜索KD树
# 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
from math import sqrt
from collections import namedtuple
# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple",
"nearest_point nearest_dist nodes_visited")
def find_nearest(tree, point):
k = len(point) # 数据维度
def travel(kd_node, target, max_dist):
if kd_node is None:
return result([0] * k, float("inf"),
0) # python中用float("inf")和float("-inf")表示正负无穷
nodes_visited = 1
s = kd_node.split # 进行分割的维度
pivot = kd_node.dom_elt # 进行分割的“轴”
if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
nearer_node = kd_node.left # 下一个访问节点为左子树根节点
further_node = kd_node.right # 同时记录下右子树
else: # 目标离右子树更近
nearer_node = kd_node.right # 下一个访问节点为右子树根节点
further_node = kd_node.left
temp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域
nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”
dist = temp1.nearest_dist # 更新最近距离
nodes_visited += temp1.nodes_visited
if dist < max_dist:
max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内
temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离
if max_dist < temp_dist: # 判断超球体是否与超平面相交
return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断
#----------------------------------------------------------------------
# 计算目标点与分割点的欧氏距离
temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))
if temp_dist < dist: # 如果“更近”
nearest = pivot # 更新最近点
dist = temp_dist # 更新最近距离
max_dist = dist # 更新超球体半径
# 检查另一个子结点对应的区域是否有更近的点
temp2 = travel(further_node, target, max_dist)
nodes_visited += temp2.nodes_visited
if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离
nearest = temp2.nearest_point # 更新最近点
dist = temp2.nearest_dist # 更新最近距离
return result(nearest, dist, nodes_visited)
return travel(tree.root, point, float("inf")) # 从根节点开始递归
搜索过程是:先向下找到目标点所在的叶子结点空间(无论最后是在叶子节点的左子空间还是右子空间),确定当前的最近距离(目标点与该叶子节点的距离),以此距离画一个超球体。
只要某一分割线与超球体相交(超球体半径<目标点与分割节点在该维度上的坐标值之差),就代表该分割线划分的子空间中的点有可能位于超球体内部(与目标点距离更近),因此需要遍历该分割线上的节点的左右子空间,没错两个空间都要,遍历方法依旧是向下到该空间叶子节点,再往上。
按这个方法向上回溯,因为目标节点处在所有祖先节点的子空间中,超球体与所有祖先的分割线都有可能相交,就有可能成为最近节点。
溯回到根,终于可以结束程序。
- 找到目标点对应的叶子节点
target[s] 目标点在该子空间的维度上的值
pivot[s] 空间节点在该子空间的维度上的值
每个节点空间都是有维度的,例如根节点对应整个空间,对应split = 0
大于:左子树;小于;右子树。
- 计算超球体半径(当前最近距离),判断超球体是否与该节点所在分割线相交,如相交,更新最小半径等,进入该叶子节点另一子空间(None)。
temp_dist = abs(pivot[s] - target[s])#该维度上坐标数值之差
max_dist < temp_dist#与超球体半径比较
与[4,7]轴相交
计算与[4,7]距离为2.69,并设为最近距离,最近节点。
前面提到,与轴相交,那么轴的两个子空间都要遍历,只不过对于叶子节点,另一子空间为空,就返回了。
- 回溯,重复上述过程
现在处于[5,4]子空间,发现维度数值差为0.5,肯定相交,计算与该节点距离为2.06,更新最短距离。
进入另一子空间,向下查找发现[2,3],重复上述操作,最短距离更新为1.08。
[5,4]子空间被遍历干净,回溯到[7,2],7-3>1.08,直接返回,最终得到一个元祖。