# Recommended kd-tree explainer (Chinese): https://www.cnblogs.com/bambipai/p/8435797.html
# Implementation:
from collections import Counter

import numpy as np
from sklearn.model_selection import train_test_split
class Node:
    """A kd-tree node: a data point plus left/right children."""

    def __init__(self, value, lson=None, rson=None):
        self.lson = lson   # left subtree (points below the split on this axis)
        self.rson = rson   # right subtree (points at/above the split)
        self.val = value   # the point stored at this node (1-D ndarray)


class KdTree:
    """Kd-tree supporting repeated nearest-neighbour queries.

    Points already reported (collected in ``self.set``) are skipped by
    ``search``, so calling ``search`` k times — appending each result to
    ``self.set`` between calls — yields the k nearest neighbours one by one.
    """

    def __init__(self, aixes):
        self.tree = None          # kept for backward compatibility (unused by creat/search)
        self.aixes = aixes        # dimensionality; splitting axis = depth % aixes
        self.nearestpoint = None  # best candidate found by the current query
        self.nearestval = 0       # distance of that candidate to the query point
        self.set = []             # neighbours already returned, excluded from search

    def creat(self, data, depth):
        """Recursively build the tree from an (n, aixes) ndarray.

        The median along the cycling axis becomes the node; returns the root
        Node, or None for empty data.
        """
        if len(data) == 0:
            return None
        axis = depth % self.aixes
        ordered = data[data[:, axis].argsort()]  # sort points on the splitting axis
        mid = len(ordered) // 2
        node = Node(ordered[mid])
        node.lson = self.creat(ordered[:mid], depth + 1)
        node.rson = self.creat(ordered[mid + 1:], depth + 1)
        return node

    def computer_dis(self, node_x, node_y):
        """Return the Euclidean distance between two points (ndarrays)."""
        return ((node_x - node_y) ** 2).sum() ** 0.5

    def order(self, node):
        """Pre-order traversal print — debugging aid."""
        if node is None:
            return
        print(node.val)
        self.order(node.lson)
        self.order(node.rson)

    def check(self, tes):
        """Return True if point *tes* was already reported as a neighbour."""
        for obj in self.set:
            if (obj == tes).all():
                return True
        return False

    def search(self, pro_data, node, depth):
        """Find the point nearest to *pro_data* that is not in ``self.set``.

        The result is left in ``self.nearestpoint`` / ``self.nearestval``;
        reset both (None / 0) before each query.
        """
        if node is None:
            return
        axis = depth % self.aixes
        # Descend first into the half-space containing the query point.
        if pro_data[axis] < node.val[axis]:
            self.search(pro_data, node.lson, depth + 1)
        else:
            self.search(pro_data, node.rson, depth + 1)
        # Consider the current node on the way back up.
        dis = self.computer_dis(pro_data, node.val)
        if self.nearestpoint is None or self.nearestval > dis:
            if not self.check(node.val):  # already reported as a neighbour -> skip
                self.nearestpoint = node.val
                self.nearestval = dis
        # Backtrack: the opposite half-space can contain a closer point only
        # if the splitting hyperplane lies within the current best radius.
        # FIX 1: when no candidate has been accepted yet (nearestpoint is
        # None, e.g. every node seen so far was in self.set) nearestval is a
        # stale 0 — we must explore the other side instead of pruning on it.
        # FIX 2: visit the child OPPOSITE to the one taken on the way down;
        # the original re-visited rson when the coordinates were equal and
        # could therefore miss the true nearest neighbour in lson.
        if self.nearestpoint is None or abs(pro_data[axis] - node.val[axis]) <= self.nearestval:
            if pro_data[axis] < node.val[axis]:
                self.search(pro_data, node.rson, depth + 1)
            else:
                self.search(pro_data, node.lson, depth + 1)
def main():
    """Classify iris.csv test samples by k-NN over a kd-tree, print accuracy.

    Reads iris.csv (comma-separated, last column = class label), splits it
    with sklearn, then for each test sample extracts the k nearest training
    points one at a time (excluding already-found ones via KT.set) and takes
    a majority vote.
    """
    train_data = []
    train_target = []
    with open("iris.csv", "r", encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split(",")
            if len(fields) < 2:
                continue  # skip blank / malformed lines (a trailing newline would crash float())
            train_data.append([float(v) for v in fields[:-1]])
            train_target.append(fields[-1])
    # Map feature tuple -> label so a neighbour found in the tree can be
    # labelled again. NOTE(review): duplicate feature vectors with different
    # labels overwrite each other here — acceptable for iris.
    table = dict(zip((tuple(row) for row in train_data), train_target))
    train_X, test_X, train_y, test_y = train_test_split(train_data, train_target, random_state=0)
    KT = KdTree(len(train_X[0]))
    tree = KT.creat(np.array(train_X), 0)
    k = 3
    out = []
    for sample, label in zip(test_X, test_y):
        votes = Counter()
        KT.set = []  # forget the neighbours of the previous sample
        for _ in range(k):
            KT.nearestpoint = None  # reset per-query search state
            KT.nearestval = 0
            KT.search(sample, tree, 0)
            res = KT.nearestpoint
            KT.set.append(res)  # exclude this neighbour from the next query
            votes[table.get(tuple(res))] += 1
        out.append(votes.most_common(1)[0][0] == label)
    print("k:", k, sum(out) / len(out))


if __name__ == "__main__":
    main()