关于KD树的介绍,许多博客已经描述的很清楚了,这里就不再叙述,不了解的可以参考https://blog.youkuaiyun.com/app_12062011/article/details/51986805
下面给出代码
"""
构建kd树,提高KNN算法的效率(数据结构要自己做出来才有趣)
1. 使用对象方法封装kd树
2. 每一个结点也用对象表示,结点的相关信息保存在实例属性中
3. 使用递归方式创建树结构以及实现树的其它逻辑结构
"""
import numpy as np
import time
class Node(object):
'''结点对象'''
def __init__(self, item=None, label=None, dim=None, parent=None, left_child=None, right_child=None):
self.item = item # 结点的值(样本信息)
self.label = label # 结点的标签
self.dim = dim # 结点的切分的维度(特征)
self.parent = parent # 父结点
self.left_child = left_child # 左子树
self.right_child = right_child # 右子树
class KDTree(object):
'''kd树'''
def __init__(self, aList, labelList):
self.__length = 0 # 不可修改
self.__root = self.__create(aList,labelList) # 根结点, 私有属性, 不可修改
def __create(self, aList, labelList, parentNode=None):
'''
创建kd树
:param aList: 需要传入一个类数组对象(行数表示样本数,列数表示特征数)
:labellist: 样本的标签
:parentNode: 父结点
:return: 根结点
'''
dataArray = np.array(aList)
m,n = dataArray.shape
labelArray = np.array(labelList).reshape(m,1)
if m == 0: # 样本集为空
return None
# 求所有特征的方差,选择最大的那个特征作为切分超平面
var_list = [np.var(dataArray[:,col]) for col in range(n)] # 获取每一个特征的方差
max_index = var_list.index(max(var_list)) # 获取最大方差特征的索引
# 样本按最大方差特征进行升序排序后,取出位于中间的样本
max_feat_ind_list = dataArray[:,max_index].argsort()
mid_item_index = max_feat_ind_list[m // 2]
if m == 1: # 样本为1时,返回自身
self.__length += 1
return Node(dim=max_index,label=labelArray[mid_item_index], item=dataArray[mid_item_index], parent=parentNode, left_child=None, right_child=None)
# 生成结点
node = Node(dim=max_index, label=labelArray[mid_item_index], item=dataArray[mid_item_index], parent=parentNode, )
# 构建有序的子树
left_tree = dataArray[max_feat_ind_list[:m // 2]] # 左子树
left_label = labelArray[max_feat_ind_list[:m // 2]] # 左子树标签
left_child = self.__create(left_tree,left_label,node)
if m == 2: # 只有左子树,无右子树
right_child = None
else:
right_tree = dataArray[max_feat_ind_list[m // 2 + 1:]] # 右子树
right_label = labelArray[max_feat_ind_list[m // 2 + 1:]] # 右子树标签
right_child = self.__create(right_tree,right_label,node)
# self.__length += 1
# 左右子树递归调用自己,返回子树根结点
node.left_child=left_child
node.right_child=right_child
self.__length += 1
return node
@pro